 
 from typing import Annotated
 import pandas as pd
+import numpy as np
 from darts import TimeSeries
 from zenml import step, get_step_context, log_metadata
 from zenml.logger import get_logger
 from zenml.client import Client
+from utils.prediction import iterative_predict
 
 logger = get_logger(__name__)
 
 
 @step
 def batch_inference_predict(
-    series: TimeSeries,
+    df: pd.DataFrame,
+    datetime_col: str = "ds",
+    target_col: str = "y",
+    freq: str = "D",
     horizon: int = 14,
 ) -> Annotated[pd.DataFrame, "predictions"]:
     """
     Perform batch inference using the trained model from Model Control Plane.
 
     Args:
-        series: Time series data for forecasting
+        df: Raw DataFrame with datetime and target columns
+        datetime_col: Name of datetime column
+        target_col: Name of target column
+        freq: Frequency string for time series
         horizon: Number of time steps to forecast
 
     Returns:
@@ -30,6 +38,21 @@ def batch_inference_predict(
     logger.info(f"Performing batch inference with horizon: {horizon}")
 
     try:
+        # Convert DataFrame to TimeSeries
+        logger.info("Converting DataFrame to TimeSeries")
+        series = TimeSeries.from_dataframe(
+            df, time_col=datetime_col, value_cols=target_col, freq=freq
+        )
+
+        # Cast to float32 for consistency with training data
+        logger.info("Converting TimeSeries to float32 for consistency")
+        series = series.astype(np.float32)
+
+        logger.info(f"Created TimeSeries with {len(series)} points")
+        logger.info(
+            f"Series range: {series.start_time()} to {series.end_time()}"
+        )
+
         # Get the model from Model Control Plane
         context = get_step_context()
         if not context.model:
@@ -72,41 +95,48 @@ def batch_inference_predict(
7295 f"Loaded model from Model Control Plane: { type (trained_model ).__name__ } "
7396 )
7497
75- # Generate predictions using improved multi-step approach (same as evaluation)
76- logger .info (
77- f"Using iterative multi-step prediction for horizon={ horizon } "
78- )
79-
80- # Use multiple prediction steps for better long-term accuracy
81- predictions_list = []
82- context_series = series
83-
84- # Predict in chunks of output_chunk_length (14 days)
85- remaining_steps = horizon
86- while remaining_steps > 0 :
87- chunk_size = min (
88- 14 , remaining_steps
89- ) # Model's output_chunk_length
90- chunk_pred = trained_model .predict (
91- n = chunk_size , series = context_series
98+ # Load the fitted scaler artifact
99+ fitted_scaler = None
100+ try :
101+ scaler_artifact = context .model .get_artifact ("fitted_scaler" )
102+ if scaler_artifact is None :
103+ raise ValueError (
104+ "fitted_scaler artifact not found in model version"
105+ )
106+ fitted_scaler = scaler_artifact .load ()
107+ logger .info ("Loaded fitted scaler artifact from model version" )
108+
109+ # Apply scaling to the input series
110+ logger .info ("Applying scaling to input series for inference" )
111+ series = fitted_scaler .transform (series )
112+ logger .info ("Scaling applied successfully" )
113+ except Exception as scaler_error :
114+ logger .error (f"Failed to load or apply scaler: { scaler_error } " )
115+ logger .warning (
116+ "Proceeding without scaling - predictions may be incorrect!"
92117 )
93- predictions_list . append ( chunk_pred )
118+ # Continue without scaling for backward compatibility
94119
95- # Extend context with the prediction for next iteration
96- context_series = context_series .concatenate (chunk_pred )
97- remaining_steps -= chunk_size
120+ # Generate predictions using improved multi-step approach
121+ predictions = iterative_predict (trained_model , series , horizon )
98122
99- # Combine all predictions
100- if len (predictions_list ) == 1 :
101- predictions = predictions_list [0 ]
123+ # Inverse transform predictions back to original scale
124+ if fitted_scaler is not None :
125+ try :
126+ logger .info (
127+ "Inverse transforming predictions back to original scale"
128+ )
129+ predictions = fitted_scaler .inverse_transform (predictions )
130+ logger .info ("Inverse transformation applied successfully" )
131+ except Exception as inverse_error :
132+ logger .error (
133+ f"Failed to inverse transform predictions: { inverse_error } "
134+ )
135+ logger .warning ("Predictions remain in scaled format!" )
102136 else :
103- predictions = predictions_list [0 ]
104- for pred_chunk in predictions_list [1 :]:
105- predictions = predictions .concatenate (pred_chunk )
106-
107- logger .info (
108- f"Generated { len (predictions )} predictions using multi-step approach"
109- )
137+ logger .warning (
138+ "No scaler available - predictions remain in original format"
139+ )
110140
111141 # Convert to DataFrame
112142 pred_df = predictions .pd_dataframe ().reset_index ()
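
The inline chunked-prediction loop removed above now lives in utils.prediction.iterative_predict, which this diff only imports and calls. A minimal sketch of that helper, reconstructed from the removed loop and assuming the chunk size still defaults to the model's output_chunk_length of 14 (the chunk_length parameter name is hypothetical), could look like:

from darts import TimeSeries


def iterative_predict(
    model, series: TimeSeries, horizon: int, chunk_length: int = 14
) -> TimeSeries:
    """Forecast `horizon` steps by iterating in `chunk_length`-sized chunks."""
    predictions_list = []
    context_series = series
    remaining_steps = horizon

    while remaining_steps > 0:
        chunk_size = min(chunk_length, remaining_steps)
        chunk_pred = model.predict(n=chunk_size, series=context_series)
        predictions_list.append(chunk_pred)

        # Feed predictions back into the context so later chunks condition on earlier ones
        context_series = context_series.concatenate(chunk_pred)
        remaining_steps -= chunk_size

    # Stitch the predicted chunks back into a single TimeSeries
    predictions = predictions_list[0]
    for pred_chunk in predictions_list[1:]:
        predictions = predictions.concatenate(pred_chunk)
    return predictions

Feeding each predicted chunk back into the context is what lets a model with a 14-step output window cover longer horizons, at the cost of compounding its own prediction error.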
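The scaler handling assumes the fitted_scaler artifact can transform a Darts TimeSeries and invert that transform. If it is a darts Scaler (an assumption; the diff does not show how the artifact was created), the scale -> predict -> inverse-scale round trip at inference time is roughly:

import numpy as np
import pandas as pd
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler

# Hypothetical data standing in for the pipeline's input DataFrame
df = pd.DataFrame({
    "ds": pd.date_range("2024-01-01", periods=60, freq="D"),
    "y": np.linspace(10.0, 100.0, 60),
})
series = TimeSeries.from_dataframe(
    df, time_col="ds", value_cols="y", freq="D"
).astype(np.float32)

scaler = Scaler()                      # wraps an sklearn MinMaxScaler by default
scaled = scaler.fit_transform(series)  # done at training time; the fitted scaler is persisted

# At inference time: reuse the fitted scaler, never refit on new data
scaled_input = scaler.transform(series)
# scaled_predictions = model.predict(...)                      # model works in scaled space
# predictions = scaler.inverse_transform(scaled_predictions)   # back to original units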