
Commit 32ba967

Update training and inference configurations
1 parent 768bda2 commit 32ba967

8 files changed: 158 additions, 43 deletions

floracast/configs/training.yaml

Lines changed: 12 additions & 10 deletions
@@ -41,20 +41,22 @@ steps:

   train_model:
     parameters:
-      input_chunk_length: 90
-      output_chunk_length: 14
-      hidden_size: 256
-      lstm_layers: 2
-      num_attention_heads: 8
-      dropout: 0.15
-      batch_size: 16
-      n_epochs: 100
+      input_chunk_length: 14 # Longer input for better pattern recognition
+      output_chunk_length: 14 # 2-week forecasting horizon for impressive demo
+      hidden_size: 16 # Smaller hidden size to prevent instability
+      lstm_layers: 1 # Single layer
+      num_attention_heads: 1 # Single head to prevent complexity issues
+      dropout: 0.0 # No dropout to eliminate regularization issues
+      batch_size: 4 # Small batch size that works with data
+      n_epochs: 5 # Few epochs
       random_state: 42
-      add_relative_index: true
+      add_relative_index: true # Required for TFT - generates future covariates
       enable_progress_bar: true
       enable_model_summary: true
+      learning_rate: 0.001 # Standard learning rate that works
+      weight_decay: 0.0 # No weight decay to eliminate regularization issues

   evaluate:
     parameters:
-      horizon: 14
+      horizon: 14 # Match updated output_chunk_length - 2 weeks
       metric: "smape"
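
Note on consuming this config: each key under steps.<step_name>.parameters must match a keyword argument of the corresponding step function (train_model, evaluate). Below is a minimal sketch of wiring the file into a run; the pipeline name and import path are assumptions, not part of this commit.

# Sketch only: entrypoint module and pipeline name are assumed.
from pipelines.training_pipeline import training_pipeline

if __name__ == "__main__":
    # ZenML injects steps.train_model.parameters.* as keyword arguments
    # into train_model() when the pipeline runs with this config file.
    training_pipeline.with_options(config_path="configs/training.yaml")()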

floracast/materializers/tft_materializer.py

Lines changed: 25 additions & 17 deletions
@@ -32,22 +32,30 @@ def load(self, data_type: Type[Any]) -> Any:
         """Load a TFT model using enhanced reconstruction strategy."""
         # using top-level TFTModel import

-        # Check what save strategies were used
-        strategy_info = self._load_strategy_info()
+        # Set PyTorch default dtype to float32 for consistent precision
+        original_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.float32)

-        # Try enhanced reconstruction if PyTorch state was saved
-        if strategy_info.get("pytorch_model_saved", False):
+        try:
+            # Check what save strategies were used
+            strategy_info = self._load_strategy_info()
+
+            # Try enhanced reconstruction if PyTorch state was saved
+            if strategy_info.get("pytorch_model_saved", False):
+                try:
+                    return self._load_with_pytorch_state()
+                except Exception as e:
+                    logger.warning(f"Enhanced reconstruction failed: {e}")
+
+            # Fallback to pickle loading
             try:
-                return self._load_with_pytorch_state()
+                return self._load_pickle_format()
             except Exception as e:
-                logger.warning(f"Enhanced reconstruction failed: {e}")
-
-        # Fallback to pickle loading
-        try:
-            return self._load_pickle_format()
-        except Exception as e:
-            logger.error(f"All loading strategies failed: {e}")
-            raise
+                logger.error(f"All loading strategies failed: {e}")
+                raise
+        finally:
+            # Restore original PyTorch dtype
+            torch.set_default_dtype(original_dtype)

     def _load_native_format(self) -> Any:
         """Load TFT model using native Darts save format."""

@@ -206,10 +214,10 @@ def _load_pickle_format(self) -> Any:
         with fileio.open(pickle_path, "rb") as f:
             model = pickle.load(f)

-        logger.warning(
-            "Loaded from pickle - internal PyTorch model may be None"
-        )
-        return model
+            logger.warning(
+                "Loaded from pickle - internal PyTorch model may be None"
+            )
+            return model

     def save(self, data: Any) -> None:
         """Save TFT model using enhanced strategy that preserves internal PyTorch model."""

floracast/materializers/timeseries_materializer.py

Lines changed: 7 additions & 0 deletions
@@ -13,6 +13,7 @@
 from typing import Any, Dict, Type

 import pandas as pd
+import numpy as np
 import matplotlib

 # Use a non-interactive backend for headless environments

@@ -102,6 +103,12 @@ def load(self, data_type: Type[Any]) -> Any:
             df, time_col=time_col, value_cols=value_cols, freq=freq
         )

+        # Convert to float32 for hardware compatibility (MPS, mixed precision training)
+        logger.debug(
+            "Converting TimeSeries to float32 for hardware compatibility"
+        )
+        ts = ts.astype(np.float32)
+
         # Restore static covariates if present
         if fileio.exists(static_covariates_path):
             with fileio.open(static_covariates_path, "r") as f:
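
For reference, a minimal sketch (synthetic data, variable names are ours) of the cast the materializer now applies on load: Darts builds float64 series from a float64 DataFrame by default, which conflicts with float32-only training setups such as MPS or 32-true precision.

import numpy as np
import pandas as pd
from darts import TimeSeries

df = pd.DataFrame(
    {
        "ds": pd.date_range("2024-01-01", periods=30, freq="D"),
        "y": np.random.rand(30),  # float64 by default
    }
)
ts = TimeSeries.from_dataframe(df, time_col="ds", value_cols="y", freq="D")
ts = ts.astype(np.float32)  # same cast the materializer performs after reload
print(ts.dtype)  # float32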

floracast/pipelines/batch_inference_pipeline.py

Lines changed: 0 additions & 4 deletions
@@ -16,12 +16,8 @@ def batch_inference_pipeline() -> None:
     """
     Batch inference pipeline that loads model from Model Control Plane and generates predictions.
     """
-    logger.info("Starting FloraCast batch inference pipeline")
-
     # Step 1: Ingest data (simulate real-time data sources)
     raw_data = ingest_data(infer=True)

     # Step 2: Generate predictions using model from MCP (with scaling handled internally)
     batch_inference_predict(df=raw_data)
-
-    logger.info("Batch inference completed. Returning predictions DataFrame.")

floracast/steps/batch_infer.py

Lines changed: 15 additions & 10 deletions
@@ -2,26 +2,34 @@
 Batch inference step for FloraCast using ZenML Model Control Plane.
 """

-from typing import Annotated
+from typing import Annotated, Tuple
 import pandas as pd
 import numpy as np
 from darts import TimeSeries
 from zenml import step, get_step_context, log_metadata
 from zenml.logger import get_logger
 from zenml.client import Client
 from utils.prediction import iterative_predict
+from materializers.timeseries_materializer import DartsTimeSeriesMaterializer

 logger = get_logger(__name__)


-@step
+@step(
+    output_materializers={
+        "prediction_series": DartsTimeSeriesMaterializer,
+    }
+)
 def batch_inference_predict(
     df: pd.DataFrame,
     datetime_col: str = "ds",
     target_col: str = "y",
     freq: str = "D",
     horizon: int = 14,
-) -> Annotated[pd.DataFrame, "predictions"]:
+) -> Tuple[
+    Annotated[pd.DataFrame, "predictions"],
+    Annotated[TimeSeries, "prediction_series"],
+]:
     """
     Perform batch inference using the trained model from Model Control Plane.

@@ -34,19 +42,16 @@ def batch_inference_predict(

     Returns:
         DataFrame containing forecast results with columns ['ds', 'yhat']
+        TimeSeries containing the forecast results
     """
     logger.info(f"Performing batch inference with horizon: {horizon}")

     try:
-        # Convert DataFrame to TimeSeries
+        # Convert DataFrame to TimeSeries and cast to float32 for consistency
         logger.info("Converting DataFrame to TimeSeries")
         series = TimeSeries.from_dataframe(
             df, time_col=datetime_col, value_cols=target_col, freq=freq
-        )
-
-        # Cast to float32 for consistency with training data
-        logger.info("Converting TimeSeries to float32 for consistency")
-        series = series.astype(np.float32)
+        ).astype(np.float32)

         logger.info(f"Created TimeSeries with {len(series)} points")
         logger.info(

@@ -167,7 +172,7 @@ def batch_inference_predict(
             }
         )

-        return pred_df
+        return pred_df, predictions

     except Exception as e:
         logger.error(f"Batch inference failed: {str(e)}")

floracast/steps/evaluate.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def create_evaluation_visualization(
         return HTMLString(error_html)


-@step
+@step(enable_cache=False)
 def evaluate(
     model: object,
     train_series: TimeSeries,

floracast/steps/preprocess.py

Lines changed: 64 additions & 0 deletions
@@ -75,6 +75,70 @@ def preprocess_data(
     train_series = train_series.astype(np.float32)
     val_series = val_series.astype(np.float32)

+    # Check for NaN/inf values in scaled data
+    train_values = train_series.pd_dataframe().values
+    val_values = val_series.pd_dataframe().values
+
+    train_nan_count = np.isnan(train_values).sum()
+    train_inf_count = np.isinf(train_values).sum()
+    val_nan_count = np.isnan(val_values).sum()
+    val_inf_count = np.isinf(val_values).sum()
+
+    logger.info(
+        f"Data quality check - Train NaN: {train_nan_count}, Train Inf: {train_inf_count}"
+    )
+    logger.info(
+        f"Data quality check - Val NaN: {val_nan_count}, Val Inf: {val_inf_count}"
+    )
+
+    # Check for extreme values that could cause numerical instability
+    train_min, train_max = train_values.min(), train_values.max()
+    val_min, val_max = val_values.min(), val_values.max()
+    logger.info(
+        f"Value ranges - Train: [{train_min:.6f}, {train_max:.6f}], Val: [{val_min:.6f}, {val_max:.6f}]"
+    )
+
+    # Flag potentially problematic values
+    needs_cleaning = (
+        train_nan_count > 0
+        or train_inf_count > 0
+        or val_nan_count > 0
+        or val_inf_count > 0
+        or abs(train_min) > 1e6
+        or abs(train_max) > 1e6
+        or abs(val_min) > 1e6
+        or abs(val_max) > 1e6
+    )
+
+    if needs_cleaning:
+        logger.warning(
+            "Found potentially problematic values in scaled data - cleaning..."
+        )
+
+        # Replace NaN/Inf and clip extreme values
+        train_df = train_series.pd_dataframe()
+        val_df = val_series.pd_dataframe()
+
+        # Handle NaN/Inf
+        train_df = train_df.replace([np.inf, -np.inf], np.nan)
+        val_df = val_df.replace([np.inf, -np.inf], np.nan)
+
+        train_df = train_df.fillna(0.0)
+        val_df = val_df.fillna(0.0)
+
+        # Clip extreme values to reasonable range
+        train_df = train_df.clip(-10.0, 10.0)
+        val_df = val_df.clip(-10.0, 10.0)
+
+        train_series = TimeSeries.from_dataframe(train_df).astype(np.float32)
+        val_series = TimeSeries.from_dataframe(val_df).astype(np.float32)
+
+        logger.info(
+            "Cleaned data - replaced NaN/Inf and clipped to [-10, 10] range"
+        )
+    else:
+        logger.info("Data quality check passed - no problematic values found")
+
     # Return fitted scaler as artifact for inference use
     logger.info("Returning fitted scaler as artifact for inference use")
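
A compressed sketch of the sanity check this hunk introduces, written as a standalone helper (the function name is ours): it reports whether a scaled series contains NaN/Inf or magnitudes large enough to destabilize training.

import numpy as np
from darts import TimeSeries


def needs_cleaning(series: TimeSeries, max_abs: float = 1e6) -> bool:
    """Return True if the series has NaN/Inf or values beyond +/- max_abs."""
    values = series.pd_dataframe().values
    finite = values[np.isfinite(values)]
    too_large = finite.size > 0 and np.abs(finite).max() > max_abs
    return bool(np.isnan(values).any() or np.isinf(values).any() or too_large)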

floracast/steps/train.py

Lines changed: 34 additions & 1 deletion
@@ -5,7 +5,7 @@
 from typing import Annotated
 import torch
 from darts import TimeSeries
-from darts.models import TFTModel
+from darts.models import TFTModel, RNNModel
 from zenml import step
 from zenml.logger import get_logger
 from materializers.tft_materializer import (

@@ -30,6 +30,8 @@ def train_model(
     add_relative_index: bool = True,
     enable_progress_bar: bool = False,
     enable_model_summary: bool = False,
+    learning_rate: float = 1e-3,
+    weight_decay: float = 1e-5,
 ) -> Annotated[TFTModel, "trained_model"]:
     """Train a TFT forecasting model.

@@ -47,6 +49,8 @@
         add_relative_index: Whether to add relative index
         enable_progress_bar: Whether to show progress bar
         enable_model_summary: Whether to show model summary
+        learning_rate: Learning rate for optimizer
+        weight_decay: Weight decay for regularization

     Returns:
         Trained TFT model

@@ -63,17 +67,46 @@ def train_model(
         "n_epochs": n_epochs,
         "random_state": random_state,
         "add_relative_index": add_relative_index,
+        "optimizer_kwargs": {
+            "lr": learning_rate,
+            "weight_decay": weight_decay,
+        },
         "pl_trainer_kwargs": {
             "enable_progress_bar": enable_progress_bar,
             "enable_model_summary": enable_model_summary,
             "precision": "32-true",  # Use 32-bit precision for better hardware compatibility
+            "gradient_clip_val": 1.0,  # Standard gradient clipping
+            "gradient_clip_algorithm": "norm",  # Clip by norm
+            "detect_anomaly": True,  # Detect NaN/inf in loss
+            "max_epochs": n_epochs,
+            "check_val_every_n_epoch": 1,  # Validate every epoch
+            "accelerator": "cpu",  # Force CPU to avoid MPS issues
         },
     }

     logger.info(f"Training TFT model with params: {model_params}")

     # Initialize TFT model
     model = TFTModel(**model_params)
+
+    # Initialize model weights with Xavier/Glorot initialization for stability
+    def init_weights(m):
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                torch.nn.init.zeros_(m.bias)
+        elif isinstance(m, torch.nn.LSTM):
+            for name, param in m.named_parameters():
+                if "weight" in name:
+                    torch.nn.init.xavier_uniform_(param)
+                elif "bias" in name:
+                    torch.nn.init.zeros_(param)
+
+    # Apply weight initialization
+    if hasattr(model, "model") and model.model is not None:
+        model.model.apply(init_weights)
+        logger.info("Applied Xavier weight initialization to model")
+
     logger.info(f"Starting TFT training with {len(train_series)} data points")

     # Train the TFT model
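
The weight-initialization block relies on torch.nn.Module.apply, which recursively visits every submodule; the hasattr/None guard matters because a Darts TFTModel typically builds its internal PyTorch module lazily, so there may be nothing to initialize before fit(). A toy sketch of the same apply pattern (the module below is ours, not the TFT network):

import torch


def init_weights(m: torch.nn.Module) -> None:
    # Xavier-init Linear layers; other module types are left untouched.
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


toy = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 1),
)
toy.apply(init_weights)  # .apply() visits the Sequential itself and each child module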
