2626
2727N1 = 100
2828N2 = 10
29+ N3 = 1000
30+
2931GRID_N = 1000
3032GRID_RG = 100
3133
34+ # parameters to test whether sampling adheres to weights
35+ COMPONENT_WEIGHTS = [0.9 , 0.05 ]
36+ LOC1 = [10.0 , 10.0 ]
37+ LOC2 = [0.0 , 0.0 ]
38+ COV1 = [1.0 , 1.0 ]
39+ COV2 = [1.0 , 1.0 ]
40+ THRESHOLD = 5.0
41+
3242
3343class TestKDE (unittest .TestCase ):
3444 def setUp (self ):
@@ -147,16 +157,15 @@ def fit_and_eval(X, X_new, bandwidth):
147157
148158 def test_sampling_adheres_to_weights (self ,):
149159 # test whether the weights passed via fit(X, sample_weight=...) are respected when sampling
150- n_samples = 2000
151-
152160 # create a GMM with 2 components with weights pi
153- pi = torch .tensor ([0.9 , 0.05 ]) # 90 % for component 1, 0.5 % for component 2 (does not need to sum to 1)
154- loc1 = torch .tensor ([10. ,10 ])
155- cov1 = torch .diag (torch .tensor ([1. ,1 ]))
156- loc2 = torch .tensor ([0. ,0 ])
157- cov2 = torch .diag (torch .tensor ([1. ,1 ]))
158- weights1 = torch .ones ((n_samples ,))* pi [0 ]
159- weights2 = torch .ones ((n_samples ,))* pi [1 ]
161+ pi = torch .tensor (COMPONENT_WEIGHTS ) # relative weights: 0.9 for component 1, 0.05 for component 2 (need not sum to 1; Categorical normalizes them)
162+ loc1 = torch .tensor (LOC1 )
163+ cov1 = torch .diag (torch .tensor (COV1 ))
164+ loc2 = torch .tensor (LOC2 )
165+ cov2 = torch .diag (torch .tensor (COV2 ))
166+
167+ weights1 = torch .ones ((N3 ,))* pi [0 ]
168+ weights2 = torch .ones ((N3 ,))* pi [1 ]
160169
161170 locs = torch .stack ([loc1 , loc2 ]) # Shape: [n_components, event_shape] = [2, 2]
162171 covs = torch .stack ([cov1 , cov2 ]) # Shape: [n_components, event_shape, event_shape] = [2, 2, 2]
@@ -166,32 +175,32 @@ def test_sampling_adheres_to_weights(self,):
166175 covariance_matrix = covs
167176 )
168177
169- # 2. Create the mixing distribution (Categorical)
178+ # Create the mixing distribution (Categorical)
170179 # pi is interpreted as weights in linear-scale
171180 mixing_distribution = dist .Categorical (probs = pi )
172181
173- # 3. Create the MixtureSameFamily distribution
182+ # Create the MixtureSameFamily distribution
174183 # This combines the mixing and component distributions
175184 gmm = dist .MixtureSameFamily (
176185 mixture_distribution = mixing_distribution ,
177186 component_distribution = component_distribution
178187 )
179188
180189
181- X = component_distribution .sample ((n_samples ,))
190+ X = component_distribution .sample ((N3 ,))
182191 X1 = X [:,0 ,:]
183192 X2 = X [:,1 ,:]
184193
185194 kde = torchkde .KernelDensity (bandwidth = .5 , kernel = 'gaussian' ) # create kde object with isotropic bandwidth matrix
186195 kde .fit (torch .concat ((X1 , X2 ), dim = 0 ), sample_weight = torch .concat ((weights1 , weights2 ), dim = 0 )) # fit kde to weighted data
187196
188- samples_from_kde = kde .sample (n_samples )
189- samples_from_gmm = gmm .sample ((n_samples ,))
197+ samples_from_kde = kde .sample (N1 )
198+ samples_from_gmm = gmm .sample ((N1 ,))
199+
200+ component_1_fraction_kde = torch .count_nonzero (torch .where (samples_from_kde [:,0 ] > THRESHOLD , 1.0 , 0.0 )) / N3
201+ component_1_fraction_gmm = torch .count_nonzero (torch .where (samples_from_gmm [:,0 ] > THRESHOLD , 1.0 , 0.0 )) / N3
190202
191- component_1_fraction_kde = torch .count_nonzero (torch .where (samples_from_kde [:,0 ] > 5. , 1.0 , 0.0 )) / n_samples
192- component_1_fraction_gmm = torch .count_nonzero (torch .where (samples_from_gmm [:,0 ] > 5. , 1.0 , 0.0 )) / n_samples
193- print (component_1_fraction_kde / component_1_fraction_gmm )
194- self .assertTrue (0.9 < (component_1_fraction_kde / component_1_fraction_gmm ) < 1.1 , "Component weights must be considered on sampling." )
203+ self .assertTrue (1.0 - TOLERANCE < (component_1_fraction_kde / component_1_fraction_gmm ) < 1.0 + TOLERANCE , "Component weights must be considered on sampling." )
195204
196205
197206def sample_from_gaussian (dim , N ):
0 commit comments