Skip to content

Commit 5e9f705

Browse files
authored
Merge pull request #507 from alan-turing-institute/hm_bug
small bug in HM
2 parents 2465e7b + c85d899 commit 5e9f705

File tree

3 files changed

+273
-77
lines changed

3 files changed

+273
-77
lines changed

autoemulate/history_matching.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -169,23 +169,13 @@ def sample_nroy(
169169

170170
# Need to handle discontinuous NROY spaces
171171
# i.e., region within min/max bounds is RO
172-
valid_samples = np.empty((0, nroy_samples.shape[1]))
173-
while len(valid_samples) < n_samples:
174-
# Generate candidates
175-
candidate_samples = np.random.uniform(
176-
min_bounds, max_bounds, size=(n_samples, nroy_samples.shape[1])
177-
)
178-
179-
# Filter valid samples based on implausibility and concatenate
180-
implausibility = self.calculate_implausibility(candidate_samples)
181-
valid_candidates = candidate_samples[implausibility["NROY"]]
182-
valid_samples = np.concatenate((valid_samples, valid_candidates), axis=0)
183172

184-
# Only return required number of samples
185-
if len(valid_samples) > n_samples:
186-
valid_samples = valid_samples[:n_samples]
173+
# Generate candidates
174+
candidate_samples = np.random.uniform(
175+
min_bounds, max_bounds, size=(n_samples, nroy_samples.shape[1])
176+
)
187177

188-
return valid_samples
178+
return candidate_samples
189179

190180
def predict(
191181
self,
@@ -298,13 +288,22 @@ def run(
298288

299289
with tqdm(total=n_waves, desc="History Matching", unit="wave") as pbar:
300290
for wave in range(n_waves):
291+
# CHECK IF WE HAVE SAMPLES TO PROCESS
292+
if len(current_samples) == 0:
293+
print(f"Wave {wave}: No valid samples found, skipping...")
294+
pbar.update(1)
295+
continue
296+
301297
# Run wave using batch processing
302298
pred_means, pred_vars, successful_samples = self.predict(
303299
x=current_samples,
304300
# Emulate predictions unless emulator_predict=False
305301
emulator=emulator if emulator_predict else None,
306302
)
307-
303+
if len(successful_samples) == 0:
304+
print(f"Wave {wave}: All simulations failed, skipping...")
305+
pbar.update(1)
306+
continue
308307
# Calculate implausibility in batch
309308
implausibility = self.calculate_implausibility(pred_means, pred_vars)
310309

@@ -328,18 +327,39 @@ def run(
328327
emulator, successful_samples, pred_means
329328
)
330329

331-
# Generate new samples for next wave
332-
if wave < n_waves - 1:
333-
if nroy_samples.size > 0:
334-
current_samples = self.sample_nroy(
335-
nroy_samples, n_samples_per_wave
336-
)
337-
else:
338-
# If no NROY points, sample from full space
339-
current_samples = self.simulator.sample_inputs(
340-
n_samples_per_wave
341-
)
342-
330+
# Generate new samples for next wave
331+
if wave < n_waves - 1:
332+
if nroy_samples.size > 0:
333+
# Sample candidates
334+
candidate_samples = self.sample_nroy(
335+
nroy_samples, n_samples_per_wave
336+
)
337+
338+
# Filter candidates using emulator before simulation
339+
if not emulator_predict and emulator is not None:
340+
pred_means, pred_vars = emulator.predict(
341+
candidate_samples, return_std=True
342+
)
343+
pred_vars = pred_vars**2
344+
345+
# Ensure correct shape for single output case
346+
if len(pred_means.shape) == 1:
347+
pred_means = pred_means.reshape(-1, 1)
348+
pred_vars = pred_vars.reshape(-1, 1)
349+
350+
implausibility = self.calculate_implausibility(
351+
pred_means, pred_vars
352+
)
353+
current_samples = candidate_samples[
354+
implausibility["NROY"]
355+
]
356+
else:
357+
current_samples = candidate_samples
358+
else:
359+
# If no NROY points, sample from full space
360+
current_samples = self.simulator.sample_inputs(
361+
n_samples_per_wave
362+
)
343363
pbar.update(1)
344364

345365
# Concatenate all samples and implausibility scores

autoemulate/history_matching_dashboard.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,26 +1209,8 @@ def display(self):
12091209

12101210
heading = widgets.HTML(value="<h2>History Matching Dashboard</h2>")
12111211

1212-
instructions = widgets.HTML(
1213-
value="""
1214-
<div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-bottom: 15px;">
1215-
<h3>Instructions:</h3>
1216-
<ul>
1217-
<li><strong>Plot Type</strong>: Select the type of visualization from the dropdown.</li>
1218-
<li><strong>Threshold</strong>: Adjust the implausibility threshold for NROY classification.</li>
1219-
<li><strong>Parameters</strong>: Select parameters to visualize (availability depends on plot type).</li>
1220-
<li><strong>Show only NROY points</strong>: When available, filters to show only viable parameter combinations.</li>
1221-
<li><strong>Update Plot</strong>: Click to regenerate the visualization with current settings.</li>
1222-
</ul>
1223-
<p><em>Note: Some controls are only available for certain plot types.</em></p>
1224-
</div>
1225-
"""
1226-
)
1227-
12281212
# Display the heading and instructions first
12291213
display(heading)
1230-
display(instructions)
1231-
12321214
display(self.main_layout)
12331215
# Initialize the first plot
12341216
self._update_plot(None)

0 commit comments

Comments
 (0)