@@ -12,24 +12,30 @@ def __init__(self, mu_arr=None, Sigma_arr=None):
1212 self .n_var = len (self .mu_arr [0 ])
1313 self .problem_name = "MOKL"
1414
15- def _evaluate_torch (self , prefs : torch .Tensor ):
15+
16+ def _evaluate_torch (self , prefs_arr : torch .Tensor ):
1617 # prefs are the coefficients.
17- mu_arr_arr = [p * mu for p , mu in zip (prefs , self .mu_arr )]
18- Sigma_arr_arr = [p * Sigma for p , Sigma in zip (prefs , self .Sigma_arr )]
19- mu = torch .sum ( torch .stack (mu_arr_arr ), axis = 0 )
20- Sigma = torch .sum ( torch .stack (Sigma_arr_arr ), axis = 0 )
21-
22- f_arr = []
23- for obj_idx in range (self .n_obj ):
24- mu_i = self .mu_arr [obj_idx ]
25- Sigma_i = self .Sigma_arr [obj_idx ]
26-
27- term1 = torch .log (torch .det (Sigma_i )) - torch .log (torch .det (Sigma_i ))
28- term2 = (mu - mu_i ) @ torch .inverse (Sigma_i ) @ (mu - mu_i )
29- term3 = torch .trace (torch .inverse (Sigma_i ) @ Sigma )
30- fi = 0.5 * (term1 + term2 + term3 - self .n_var )
31- f_arr .append (fi )
32- return torch .stack (f_arr )
18+ f_arr_all = []
19+ for prefs in prefs_arr :
20+ Sigma_inverse_arr_arr = [p * torch .inverse (Sigma_ ) for p , Sigma_ in zip (prefs , self .Sigma_arr )]
21+ Sigma = torch .inverse (torch .sum ( torch .stack (Sigma_inverse_arr_arr ), axis = 0 ))
22+
23+ mu_arr_arr = [p * torch .inverse (Sigma_ ) @ mu
24+ for p , mu , Sigma_ in zip (prefs , self .mu_arr , self .Sigma_arr )]
25+ mu = Sigma @ torch .sum ( torch .stack (mu_arr_arr ), axis = 0 )
26+ f_arr = []
27+
28+ for obj_idx in range (self .n_obj ):
29+ mu_i = self .mu_arr [obj_idx ]
30+ Sigma_i = self .Sigma_arr [obj_idx ]
31+ term1 = torch .log (torch .det (Sigma_i )) - torch .log (torch .det (Sigma ))
32+ term2 = (mu - mu_i ) @ torch .inverse (Sigma_i ) @ (mu - mu_i )
33+ term3 = torch .trace (torch .inverse (Sigma_i ) @ Sigma )
34+ fi = 0.5 * (term1 + term2 + term3 - self .n_var )
35+ f_arr .append (fi )
36+ f_arr = torch .stack (f_arr )
37+ f_arr_all .append (f_arr )
38+ return torch .stack (f_arr_all )
3339
3440
3541if __name__ == '__main__' :
0 commit comments