Skip to content

Commit 1f3f5a7

Browse files
committed
add apriori parameter fixing functionality with unit test
1 parent 8f1e113 commit 1f3f5a7

File tree

9 files changed

+415
-96
lines changed

9 files changed

+415
-96
lines changed

src/xinv/core/tools.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import numpy as np
55
import xarray as xr
6+
from xinv.core.attrs import find_xinv_unk_coord,get_state,get_type,xinv_tp,xinv_st,find_xinv_group_coords
7+
from xinv.core.logging import xinvlogger
68

79
def find_ilocs(dsneq,dim,elements):
810
idxname=f"_{dim}_idx"
@@ -27,3 +29,80 @@ def find_overlap_coords(coord1,coord2):
2729
uniq2=np.setdiff1d(coord2, coord1, assume_unique=False)
2830

2931
return uniq1, intersect, uniq2
32+
33+
34+
def find_unk_idx(dsneq,sort=True,**kwargs):
35+
"""
36+
Find the indices of a set of unknown parameters in the unknown vector of a normal equation system
37+
Parameters:
38+
-----------
39+
dsneq: xarray.Dataset
40+
Dataset containing the normal equation system
41+
sort: bool, optional
42+
If True, the output indices are sorted in ascending order. The default is True.
43+
kwargs: dict
44+
keyword arguments with dimension names as keys and the elements to find as values
45+
e.g. poly=[0,1],harmonics_seasonal=[1,2]
46+
Returns:
47+
--------
48+
idxfound: np.ndarray or None
49+
Indices of the found unknown parameters in the unknown vector, or None if no parameters were found
50+
idxremaining: np.ndarray or None
51+
52+
Indices of the unknown parameters that are complementary to the ones found
53+
idxnotfound: np.ndarray or None
54+
Indices of the requested unknown parameters that were not found in the system
55+
56+
"""
57+
58+
xunk_co=find_xinv_unk_coord(dsneq)
59+
60+
unkdim=xunk_co.dims[0]
61+
62+
group_id_co=None
63+
group_seq_co=None
64+
65+
notfound=[]
66+
found=[]
67+
remaining=[]
68+
for coname,searchparams in kwargs.items():
69+
co_search=dsneq[coname]
70+
if get_type(co_search) != xinv_tp.unk_co:
71+
raise ValueError(f"Missing xinv_type: supplied coordinate name {coname} has no valid relation with unknown coordinate {xunk_co.name}")
72+
dimname=co_search.dims[0]
73+
if searchparams is not type(xr.DataArray):
74+
#turn into DataArray
75+
searchparams=xr.DataArray(searchparams,dims=dimname)
76+
#find unique and overlapping coordinates over the unknown dimension
77+
notfnd,fnd,remng=find_overlap_coords(searchparams,co_search)
78+
79+
if get_state(co_search) == xinv_st.unlinked:
80+
#we may have to apply an additional lookup in the group unknown multiindex
81+
if group_id_co is None and group_seq_co is None:
82+
group_id_co,group_seq_co=find_xinv_group_coords(dsneq)
83+
notfound.extend([(coname,i) for i in find_ilocs(dsneq,coname,notfnd)])
84+
found.extend([(coname,i) for i in find_ilocs(dsneq,coname,fnd)])
85+
remaining.extend([(coname,i) for i in find_ilocs(dsneq,coname,remng)])
86+
87+
elif get_state(co_search) == xinv_st.linked:
88+
notfound.extend(notfnd)
89+
found.extend(fnd)
90+
remaining.extend(remng)
91+
else:
92+
raise ValueError(f"Reduction coordinate {coname} has no valid link state")
93+
94+
95+
96+
#index vector of the found parameters
97+
idxfound=find_ilocs(dsneq,unkdim,found) if len(found) > 0 else None
98+
idxnotfound=find_ilocs(dsneq,unkdim,notfound) if len(notfound) > 0 else None
99+
idxremaining=find_ilocs(dsneq,unkdim,remaining) if len(remaining) > 0 else None
100+
if sort:
101+
if idxfound is not None:
102+
idxfound=np.sort(idxfound)
103+
if idxremaining is not None:
104+
idxremaining=np.sort(idxremaining)
105+
if idxnotfound is not None:
106+
idxnotfound=np.sort(idxnotfound)
107+
108+
return idxfound,idxremaining,idxnotfound

src/xinv/linalg/inplace.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
from scipy.linalg import cholesky
1111
from scipy.linalg.lapack import dpotri
12-
from scipy.linalg.blas import dtrsm,dsyrk
13-
12+
from scipy.linalg.blas import dtrsm,dsyrk,dsymm
1413

14+
import ctypes
1515
from enum import Enum
1616

1717
class MemLayout(Enum):
@@ -184,3 +184,71 @@ def dsyrk_inplace(N,A,trans=0,beta=0.0,alpha=1.0):
184184
A[()]=adat
185185

186186
return N
187+
188+
def dsymm_inplace(A:xr.DataArray,B:xr.DataArray,C:xr.DataArray,alpha=1.0,beta=0.0):
189+
"""
190+
Symmetric matrix times matrix multiplication
191+
C= alpha*A*B +beta*C
192+
"""
193+
#
194+
# C := alpha*A*B + beta*C, (side=0)
195+
#or
196+
# C := alpha*B*A + beta*C, (side=1)
197+
198+
199+
#some basic checks
200+
if A.shape[0] != A.shape[1]:
201+
raise ValueError("Matrix A must be square")
202+
if A.shape[1] != B.shape[0]:
203+
raise ValueError("Inner dimensions of A and B do not match")
204+
if A.shape[0] != C.shape[0]:
205+
raise ValueError("Output matrix C rows do not match matrix A")
206+
if B.shape[1] != C.shape[1]:
207+
raise ValueError("Columns of C do not match that of B")
208+
209+
210+
211+
restore=False
212+
c_layout=memlayout(C.data)
213+
214+
if c_layout == MemLayout.F_cont:
215+
cdat=C.data
216+
side=0 #0 means left
217+
elif c_layout == MemLayout.C_cont:
218+
cdat=C.data.T
219+
#switch side to allow for in place operation
220+
side=1
221+
else:
222+
xinvlogger.warning("C matrix is not C or F contiguous, applying copy and restore")
223+
cdat=C.data.copy(order='F')
224+
side=0 #0 means left
225+
restore=True
226+
227+
lower=islower(A)
228+
A_layout=memlayout(A.data)
229+
if A_layout == MemLayout.C_cont:
230+
#To prevent an additional copy by f2py we can fake a F contigous array by transposing and switching lower/upper
231+
lower=1-lower
232+
adat=A.data.T
233+
else:
234+
adat=A.data
235+
236+
237+
if side == 1:
238+
bdat=B.data.T
239+
else:
240+
bdat=B.data
241+
242+
b_layout=memlayout(B.data)
243+
244+
245+
if b_layout != c_layout:
246+
xinvlogger.warning("B matrix contiguousness is inconsistent with C, a copy will be made by f2py")
247+
#dsymm(alpha, a, b[, beta, c, side, lower, overwrite_c])
248+
dsymm(alpha,adat,bdat,beta=beta,c=cdat,side=side,lower=lower,overwrite_c=1)
249+
250+
251+
if restore:
252+
C.data[()]=cdat
253+
254+
return C #although changed in place

src/xinv/neq/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@
33
from .solve import *
44
from .add import *
55
from .reduce import *
6+
from .set_apriori import *
7+
from .fix import *
8+

src/xinv/neq/fix.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
## Permissions: See the xinv license file https://raw.githubusercontent.com/strawpants/xinv/master/LICENSE
2+
## Copyright (c) 2025 Roelof Rietbroek, r.rietbroek@utwente.nl
3+
4+
5+
from xinv.core.attrs import find_component,xinv_tp,find_xinv_group_coords,get_xunk_size_coname
6+
7+
from xinv.core.tools import find_unk_idx
8+
import numpy as np
9+
from xinv.core.logging import xinvlogger
10+
11+
12+
def ifix(dsneq,idx,keep=False):
13+
"""
14+
Fix/remove parameters in a normal equation system to their apriori values using indexing
15+
Parameters:
16+
-----------
17+
dsneq: xarray.Dataset
18+
Dataset containing the normal equation system to be fixed
19+
idx: list or np.ndarray
20+
index of the coordinates to fix or keep in the system
21+
22+
"""
23+
24+
u_sz,unkdim=get_xunk_size_coname(dsneq)
25+
#compute the complementary index
26+
if keep:
27+
idxkeep=idx
28+
if idxkeep.dtype == bool:
29+
idxfix=~idxkeep
30+
else:
31+
idxfix=~np.isin(np.arange(u_sz),idxkeep)
32+
else:
33+
idxfix=idx
34+
if idxfix.dtype == bool:
35+
idxkeep=~idxfix
36+
else:
37+
idxkeep=~np.isin(np.arange(u_sz),idxfix)
38+
39+
o_dsneq=dsneq.sel({unkdim:idxkeep,unkdim+'_':idxkeep})
40+
io_npara=find_component(o_dsneq,xinv_tp.npara)
41+
#update amount of unknown parameters
42+
io_npara[()]-=len(idxfix)
43+
44+
return o_dsneq
45+
46+
47+
def fix(dsneq, keep=False,**kwargs):
48+
"""
49+
Fix/remove parameters from a Normal equation system using coordinate labelling
50+
Parameters:
51+
-----------
52+
dsneq: xarray.Dataset
53+
Dataset containing the normal equation system to be fixed
54+
keep: bool, optional
55+
If True, the parameters are kept instead of fixed. The default is False.
56+
**kwargs:
57+
keyword arguments with the dimension name as key and a list of coordinate labels to be fixed/removed from the system
58+
coord1 = fixlabels1 , .. coord2 = fixlabels2
59+
60+
"""
61+
62+
idxfound,idxremaining,idxnotfound=find_unk_idx(dsneq,**kwargs)
63+
if idxnotfound is not None:
64+
xinvlogger.warning(f"Fix parameters contain values {idxnotfound} which are not found in the input normal equation system, ignoring those")
65+
66+
if (not keep and idxremaining is None) or (keep and idxfound is None):
67+
xinvlogger.warning("Nothing to fix, returning input")
68+
return dsneq
69+
elif (not keep and idxfound is None) or (keep and idxremaining is None):
70+
#cannot fix all unknowns
71+
raise ValueError("Fix parameters contain all unknown parameters, cannot fix all unknowns")
72+
73+
return ifix(dsneq,idx=idxfound,keep=keep)
74+
75+
def groupfix(dsneq,groupname,keep=False):
76+
"""
77+
Fix/remove by groupname a group of parameters from a normal equation system
78+
Parameters
79+
----------
80+
dsneq : xr.Dataset
81+
Dataset containing the normal equation system to be fixed/removed
82+
groupname : str
83+
The groupname of the parameters to be fixed/removed
84+
keep : bool, optional
85+
If True, the group parameters are kept instead of fixed. The default is False.
86+
"""
87+
88+
89+
#test whether the groupname actually exists
90+
grpid_co,grpseq_co=find_xinv_group_coords(dsneq)
91+
if grpid_co is None or grpseq_co is None:
92+
raise ValueError("No group coordinates found in the normal equation system")
93+
94+
if groupname not in grpid_co.data:
95+
raise ValueError(f"Groupname {groupname} not found in the normal equation system")
96+
97+
idx=grpid_co.data == groupname
98+
return ifix(dsneq,idx=idx,keep=keep)
99+
100+

src/xinv/neq/neq.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,6 @@ def zeros(rhsdims,coords,lower=0,norder='C'):
5959

6060

6161

62-
def reduce(dsneq:xr.Dataset, idx):
63-
#tbd reduce (implicitly solve) variables from a normal equation system spanned by idx
64-
raise NotImplementedError("Reduce operation not yet implemented")
65-
66-
def fix(dsneq:xr.Dataset, idx):
67-
#tbd fix and (remove) solve) variables from a normal equation system spanned by idx
68-
raise NotImplementedError("Fix operation not yet implemented")
69-
70-
def set_apriori(dsneq:xr.Dataset, daapri:xr.DataArray):
71-
#tbd set/change apriori values in a normal equation system
72-
raise NotImplementedError("Set apriori values not yet implemented")
7362

7463

7564

src/xinv/neq/reduce.py

Lines changed: 9 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from xinv.core.attrs import find_xinv_unk_coord,find_neq_components,islower,xinv_st,get_state,get_type,xinv_tp,find_xinv_group_coords,get_xunk_size_coname
66

7-
from xinv.core.tools import find_overlap_coords,find_ilocs
7+
from xinv.core.tools import find_overlap_coords,find_ilocs,find_unk_idx
88
import numpy as np
99
import xarray as xr
1010
from xinv.core.logging import xinvlogger
@@ -128,59 +128,18 @@ def reduce(dsneq, keep=False,**kwargs):
128128
129129
"""
130130

131+
idxfound,idxremaining,idxnotfound=find_unk_idx(dsneq,**kwargs)
132+
if idxnotfound is not None:
133+
xinvlogger.warning(f"Reduction parameters contain values {idxnotfound} which are not found in the input normal equation system, ignoring those")
131134

132-
xunk_co=find_xinv_unk_coord(dsneq)
133-
134-
unkdim=xunk_co.dims[0]
135-
136-
group_id_co=None
137-
group_seq_co=None
138-
139-
reduniq=[]
140-
common=[]
141-
neuniq=[]
142-
for coname,redparams in kwargs.items():
143-
co_search=dsneq[coname]
144-
if get_type(co_search) != xinv_tp.unk_co:
145-
raise ValueError(f"Missing xinv_type: Reduction coordinate {coname} has no valid relation with unknown coordinate {xunk_co.name}")
146-
dimname=co_search.dims[0]
147-
if redparams is not type(xr.DataArray):
148-
#turn into DataArray
149-
redparams=xr.DataArray(redparams,dims=dimname)
150-
#find unique and overlapping coordinates over the unknown dimension
151-
redu,com,neun=find_overlap_coords(redparams,co_search)
152-
153-
if get_state(co_search) == xinv_st.unlinked:
154-
#we may have to apply an additional lookup in the group unknown multiindex
155-
if group_id_co is None and group_seq_co is None:
156-
group_id_co,group_seq_co=find_xinv_group_coords(dsneq)
157-
common.extend([(coname,i) for i in find_ilocs(dsneq,coname,com)])
158-
reduniq.extend([(coname,i) for i in find_ilocs(dsneq,coname,redu)])
159-
neuniq.extend([(coname,i) for i in find_ilocs(dsneq,coname,neun)])
160-
161-
elif get_state(co_search) == xinv_st.linked:
162-
reduniq.extend(redu)
163-
common.extend(com)
164-
neuniq.extend(neun)
165-
else:
166-
raise ValueError(f"Reduction coordinate {coname} has no valid link state")
167-
168-
169-
170-
if len(reduniq) > 0:
171-
xinvlogger.warning(f"Reduction parameters contain values {reduniq} which are not found in the input normal equation system, ignoring those")
172-
173-
if len(neuniq) == 0:
174-
raise ValueError("Reduction parameters contain all unknown parameters, cannot reduce all unknowns")
175-
176-
if len(common) == 0:
135+
if (not keep and idxremaining is None) or (keep and idxfound is None):
177136
xinvlogger.warning("Nothing to reduce, returning input")
178137
return dsneq
179-
180-
#index vector of the to be reduced parameters
181-
idxreduce=find_ilocs(dsneq,unkdim,common)
138+
elif (not keep and idxfound is None) or (keep and idxremaining is None):
139+
#cannot reduce all unknowns
140+
raise ValueError("Reduction parameters contain all unknown parameters, cannot reduce all unknowns")
182141

183-
return ireduce(dsneq,idxreduce,keep)
142+
return ireduce(dsneq,idxfound,keep)
184143

185144
def groupreduce(dsneq,groupname,keep=False):
186145
"""

0 commit comments

Comments
 (0)