55from typing import Annotated
66import torch
77from darts import TimeSeries
8- from darts .models import TFTModel
8+ from darts .models import TFTModel , RNNModel
99from zenml import step
1010from zenml .logger import get_logger
1111from materializers .tft_materializer import (
@@ -30,6 +30,8 @@ def train_model(
3030 add_relative_index : bool = True ,
3131 enable_progress_bar : bool = False ,
3232 enable_model_summary : bool = False ,
33+ learning_rate : float = 1e-3 ,
34+ weight_decay : float = 1e-5 ,
3335) -> Annotated [TFTModel , "trained_model" ]:
3436 """Train a TFT forecasting model.
3537
@@ -47,6 +49,8 @@ def train_model(
4749 add_relative_index: Whether to add relative index
4850 enable_progress_bar: Whether to show progress bar
4951 enable_model_summary: Whether to show model summary
52+ learning_rate: Learning rate for optimizer
53+ weight_decay: Weight decay for regularization
5054
5155 Returns:
5256 Trained TFT model
@@ -63,17 +67,46 @@ def train_model(
6367 "n_epochs" : n_epochs ,
6468 "random_state" : random_state ,
6569 "add_relative_index" : add_relative_index ,
70+ "optimizer_kwargs" : {
71+ "lr" : learning_rate ,
72+ "weight_decay" : weight_decay ,
73+ },
6674 "pl_trainer_kwargs" : {
6775 "enable_progress_bar" : enable_progress_bar ,
6876 "enable_model_summary" : enable_model_summary ,
6977 "precision" : "32-true" , # Use 32-bit precision for better hardware compatibility
78+ "gradient_clip_val" : 1.0 , # Standard gradient clipping
79+ "gradient_clip_algorithm" : "norm" , # Clip by norm
80+ "detect_anomaly" : True , # Detect NaN/inf in loss
81+ "max_epochs" : n_epochs ,
82+ "check_val_every_n_epoch" : 1 , # Validate every epoch
83+ "accelerator" : "cpu" , # Force CPU to avoid MPS issues
7084 },
7185 }
7286
7387 logger .info (f"Training TFT model with params: { model_params } " )
7488
7589 # Initialize TFT model
7690 model = TFTModel (** model_params )
91+
92+ # Initialize model weights with Xavier/Glorot initialization for stability
def init_weights(m):
    """Re-initialise one submodule in place; meant for ``Module.apply``.

    * ``torch.nn.Linear``: weight gets Xavier-uniform values and the bias,
      when the layer has one, is set to zero.
    * ``torch.nn.LSTM``: every parameter whose name contains ``"weight"``
      gets Xavier-uniform values; every parameter whose name contains
      ``"bias"`` is set to zero.
    * Any other module type is left untouched.

    Args:
        m: The submodule currently being visited.
    """
    xavier = torch.nn.init.xavier_uniform_
    zero = torch.nn.init.zeros_

    if isinstance(m, torch.nn.Linear):
        xavier(m.weight)
        if m.bias is not None:
            zero(m.bias)
        return

    if not isinstance(m, torch.nn.LSTM):
        return

    # Visit parameters in their registration order so random draws are
    # consumed in the same sequence on every run.
    for param_name, tensor in m.named_parameters():
        if "weight" in param_name:
            xavier(tensor)
        elif "bias" in param_name:
            zero(tensor)
104+
105+ # Apply weight initialization
106+ if hasattr (model , "model" ) and model .model is not None :
107+ model .model .apply (init_weights )
108+ logger .info ("Applied Xavier weight initialization to model" )
109+
77110 logger .info (f"Starting TFT training with { len (train_series )} data points" )
78111
79112 # Train the TFT model
0 commit comments