Skip to content

Commit 9812b48

Browse files
authored
[ENH] Refactor N-BEATS blocks to separate KAN logic (#2012)
#### Reference Issues/PRs Fixes #2011 --- #### What does this implement/fix? Explain your changes. This PR refactors the N-BEATS architecture to explicitly separate standard MLP blocks from KAN-based blocks, removing the previous `if/else` dependency on the `use_kan` flag. * **Split Classes:** Created distinct classes for all variants to ensure clean separation of concerns: * `NBEATSBlock` / `NBEATSBlockKAN` * `NBEATSSeasonalBlock` / `NBEATSSeasonalBlockKAN` * `NBEATSTrendBlock` / `NBEATSTrendBlockKAN` * **Mixin Pattern:** Implemented `SeasonalMixin` and `TrendMixin` to handle the shared mathematical logic (Fourier series and Polynomial generation) between standard and KAN blocks. * Unpacked the `kan_params` dictionary into explicit `__init__` arguments (e.g., `num`, `k`, `noise_scale`) in the KAN classes. * Updated `NBeatsKAN` to correctly initialize these new specific block classes based on the `stack_types` argument. --- #### What should a reviewer concentrate their feedback on? * Please verify that `SeasonalMixin` and `TrendMixin` correctly capture the shared logic without unwanted side effects. * Check if the use of `**kan_kwargs` in the subclasses correctly passes parameters up to `NBEATSBlockKAN`. * I restored `F.relu` in `NBEATSGenericBlockKAN` to ensure consistency with the standard generic block. --- #### Did you add any tests for the change? Yes, I added `test_nbeats_kan_integration` which covers the following: 1. Verifies Init -> Train -> Save -> Load -> Predict flows working together. 2. Asserts that KAN-specific hyperparameters (e.g., `num=10`) are correctly saved to `hparams` and restored upon checkpoint loading. 3. Asserts that the model actually initializes the correct class types (e.g., `NBEATSTrendBlockKAN`) when `stack_types` are requested.
1 parent dbfd38b commit 9812b48

File tree

4 files changed

+646
-237
lines changed

4 files changed

+646
-237
lines changed

pytorch_forecasting/layers/_nbeats/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,22 @@
44

55
from pytorch_forecasting.layers._nbeats._blocks import (
66
NBEATSBlock,
7+
NBEATSBlockKAN,
78
NBEATSGenericBlock,
9+
NBEATSGenericBlockKAN,
810
NBEATSSeasonalBlock,
11+
NBEATSSeasonalBlockKAN,
912
NBEATSTrendBlock,
13+
NBEATSTrendBlockKAN,
1014
)
1115

1216
__all__ = [
1317
"NBEATSBlock",
1418
"NBEATSGenericBlock",
1519
"NBEATSSeasonalBlock",
1620
"NBEATSTrendBlock",
21+
"NBEATSBlockKAN",
22+
"NBEATSGenericBlockKAN",
23+
"NBEATSSeasonalBlockKAN",
24+
"NBEATSTrendBlockKAN",
1725
]

0 commit comments

Comments
 (0)