Commit 9812b48
authored
[ENH] Refactor N-BEATS blocks to separate KAN logic (#2012)
#### Reference Issues/PRs
Fixes #2011
---
#### What does this implement/fix? Explain your changes.
This PR refactors the N-BEATS architecture to explicitly separate
standard MLP blocks from KAN-based blocks, removing the previous
`if/else` dependency on the `use_kan` flag.
* **Split Classes:** Created distinct classes for all variants to ensure
clean separation of concerns:
* `NBEATSBlock` / `NBEATSBlockKAN`
* `NBEATSSeasonalBlock` / `NBEATSSeasonalBlockKAN`
* `NBEATSTrendBlock` / `NBEATSTrendBlockKAN`
* **Mixin Pattern:** Implemented `SeasonalMixin` and `TrendMixin` to
handle the shared mathematical logic (Fourier series and Polynomial
generation) between standard and KAN blocks.
* Unpacked the `kan_params` dictionary into explicit `__init__`
arguments (e.g., `num`, `k`, `noise_scale`) in the KAN classes.
* Updated `NBeatsKAN` to correctly initialize these new specific block
classes based on the `stack_types` argument.
---
#### What should a reviewer concentrate their feedback on?
* Please verify that `SeasonalMixin` and `TrendMixin` correctly capture
the shared logic without unwanted side effects.
* Check if the use of `**kan_kwargs` in the subclasses correctly passes
parameters up to `NBEATSBlockKAN`.
* I restored `F.relu` in `NBEATSGenericBlockKAN` to ensure consistency
with the standard generic block.
---
#### Did you add any tests for the change?
Yes, I added `test_nbeats_kan_integration` which covers the following:
1. Verifies Init -> Train -> Save -> Load -> Predict flows working
together.
2. Asserts that KAN-specific hyperparameters (e.g., `num=10`) are
correctly saved to `hparams` and restored upon checkpoint loading.
3. Asserts that the model actually initializes the correct class types
(e.g., `NBEATSTrendBlockKAN`) when `stack_types` are requested.1 parent dbfd38b commit 9812b48
File tree
4 files changed
+646
-237
lines changed- pytorch_forecasting
- layers/_nbeats
- models/nbeats
- tests/test_models
4 files changed
+646
-237
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| 7 | + | |
7 | 8 | | |
| 9 | + | |
8 | 10 | | |
| 11 | + | |
9 | 12 | | |
| 13 | + | |
10 | 14 | | |
11 | 15 | | |
12 | 16 | | |
13 | 17 | | |
14 | 18 | | |
15 | 19 | | |
16 | 20 | | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
17 | 25 | | |
0 commit comments