@@ -169,23 +169,13 @@ def sample_nroy(
169169
170170 # Need to handle discontinuous NROY spaces
171171 # i.e., region within min/max bounds is RO
172- valid_samples = np .empty ((0 , nroy_samples .shape [1 ]))
173- while len (valid_samples ) < n_samples :
174- # Generate candidates
175- candidate_samples = np .random .uniform (
176- min_bounds , max_bounds , size = (n_samples , nroy_samples .shape [1 ])
177- )
178-
179- # Filter valid samples based on implausibility and concatenate
180- implausibility = self .calculate_implausibility (candidate_samples )
181- valid_candidates = candidate_samples [implausibility ["NROY" ]]
182- valid_samples = np .concatenate ((valid_samples , valid_candidates ), axis = 0 )
183172
184- # Only return required number of samples
185- if len (valid_samples ) > n_samples :
186- valid_samples = valid_samples [:n_samples ]
173+ # Generate candidates
174+ candidate_samples = np .random .uniform (
175+ min_bounds , max_bounds , size = (n_samples , nroy_samples .shape [1 ])
176+ )
187177
188- return valid_samples
178+ return candidate_samples
189179
190180 def predict (
191181 self ,
@@ -298,13 +288,22 @@ def run(
298288
299289 with tqdm (total = n_waves , desc = "History Matching" , unit = "wave" ) as pbar :
300290 for wave in range (n_waves ):
291+ # CHECK IF WE HAVE SAMPLES TO PROCESS
292+ if len (current_samples ) == 0 :
293+ print (f"Wave { wave } : No valid samples found, skipping..." )
294+ pbar .update (1 )
295+ continue
296+
301297 # Run wave using batch processing
302298 pred_means , pred_vars , successful_samples = self .predict (
303299 x = current_samples ,
304300 # Emulate predictions unless emulator_predict=False
305301 emulator = emulator if emulator_predict else None ,
306302 )
307-
303+ if len (successful_samples ) == 0 :
304+ print (f"Wave { wave } : All simulations failed, skipping..." )
305+ pbar .update (1 )
306+ continue
308307 # Calculate implausibility in batch
309308 implausibility = self .calculate_implausibility (pred_means , pred_vars )
310309
@@ -328,18 +327,39 @@ def run(
328327 emulator , successful_samples , pred_means
329328 )
330329
331- # Generate new samples for next wave
332- if wave < n_waves - 1 :
333- if nroy_samples .size > 0 :
334- current_samples = self .sample_nroy (
335- nroy_samples , n_samples_per_wave
336- )
337- else :
338- # If no NROY points, sample from full space
339- current_samples = self .simulator .sample_inputs (
340- n_samples_per_wave
341- )
342-
330+ # Generate new samples for next wave
331+ if wave < n_waves - 1 :
332+ if nroy_samples .size > 0 :
333+ # Sample candidates
334+ candidate_samples = self .sample_nroy (
335+ nroy_samples , n_samples_per_wave
336+ )
337+
338+ # Filter candidates using emulator before simulation
339+ if not emulator_predict and emulator is not None :
340+ pred_means , pred_vars = emulator .predict (
341+ candidate_samples , return_std = True
342+ )
343+ pred_vars = pred_vars ** 2
344+
345+ # Ensure correct shape for single output case
346+ if len (pred_means .shape ) == 1 :
347+ pred_means = pred_means .reshape (- 1 , 1 )
348+ pred_vars = pred_vars .reshape (- 1 , 1 )
349+
350+ implausibility = self .calculate_implausibility (
351+ pred_means , pred_vars
352+ )
353+ current_samples = candidate_samples [
354+ implausibility ["NROY" ]
355+ ]
356+ else :
357+ current_samples = candidate_samples
358+ else :
359+ # If no NROY points, sample from full space
360+ current_samples = self .simulator .sample_inputs (
361+ n_samples_per_wave
362+ )
343363 pbar .update (1 )
344364
345365 # Concatenate all samples and implausibility scores
0 commit comments