Skip to content

Commit 544d2dc

Browse files
authored
Merge pull request #161 from wilsonrljr/fix/get_iterable_list
fix bilinear basis for miso systems
2 parents bc87dd5 + 46bed4d commit 544d2dc

File tree

2 files changed

+37
-19
lines changed

2 files changed

+37
-19
lines changed

sysidentpy/basis_function/_bilinear.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44

5-
from itertools import combinations_with_replacement
5+
from itertools import combinations_with_replacement, chain
66
from typing import Optional
77
import numpy as np
88

@@ -101,19 +101,29 @@ def fit(
101101
"In this case, you have a linear polynomial model.",
102102
stacklevel=2,
103103
)
104-
else:
105-
ny = self.get_max_ylag(ylag)
106-
nx = self.get_max_xlag(xlag)
107-
combination_ylag = list(
108-
combinations_with_replacement(list(range(1, ny + 1)), self.degree)
109-
)
110-
combination_xlag = list(
104+
105+
ny = self.get_max_ylag(ylag)
106+
combination_ylag = list(
107+
combinations_with_replacement(list(range(1, ny + 1)), self.degree)
108+
)
109+
if isinstance(xlag, int):
110+
xlag = [xlag]
111+
112+
combination_xlag = []
113+
ni = 0
114+
for lag in xlag:
115+
nx = self.get_max_xlag(lag)
116+
combination_lag = list(
111117
combinations_with_replacement(
112-
list(range(ny + 1, nx + ny + 1)), self.degree
118+
list(range(ny + 1 + ni, nx + ny + 1 + ni)), self.degree
113119
)
114120
)
115-
combinations_xy = combination_xlag + combination_ylag
116-
combination_list = list(set(combination_list) - set(combinations_xy))
121+
combination_xlag.append(combination_lag)
122+
ni += nx
123+
124+
combination_xlag = list(chain.from_iterable(combination_xlag))
125+
combinations_xy = combination_xlag + combination_ylag
126+
combination_list = list(set(combination_list) - set(combinations_xy))
117127

118128
if predefined_regressors is not None:
119129
combination_list = [

sysidentpy/basis_function/basis_function_base.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,29 @@ def get_max_ylag(self, ylag: int = 1):
5656
return ny
5757

5858
def get_max_xlag(self, xlag: int = 1):
59-
"""Get maximum xlag.
59+
"""Get maximum value from various xlag structures.
6060
6161
Parameters
6262
----------
63-
xlag : ndarray of int
64-
The range of lags according to user definition.
63+
xlag : int, list of int, or nested list of int
64+
Input that can be a single integer, a list, or a nested list.
6565
6666
Returns
6767
-------
68-
nx : list
69-
Maximum value of xlag.
70-
68+
int
69+
Maximum value found.
7170
"""
72-
nx = np.max(list(chain.from_iterable([[np.array(xlag, dtype=object)]])))
73-
return nx
71+
if isinstance(xlag, int): # Case 1: Single integer
72+
return xlag
73+
74+
if isinstance(xlag, list):
75+
# Case 2: Flat list of integers
76+
if all(isinstance(i, int) for i in xlag):
77+
return max(xlag)
78+
# Case 3: Nested list
79+
return max(chain.from_iterable(xlag))
80+
81+
raise ValueError("Unsupported data type for xlag")
7482

7583
def get_iterable_list(
7684
self, ylag: int = 1, xlag: int = 1, model_type: str = "NARMAX"

0 commit comments

Comments
 (0)