Skip to content

Commit d96b748

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Add Progression Plots for MapMetric experiments to ResultsAnalysis (facebook#4705)
Summary: This diff adds learning curve visualization (progression plots) to ResultsAnalysis for experiments with MapData and MapMetrics. Differential Revision: D89776181 Privacy Context Container: L1307644
1 parent 3af75bc commit d96b748

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

ax/analysis/plotly/progression.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@
2626
from plotly import graph_objects as go
2727
from pyre_extensions import none_throws, override
2828

29+
PROGRESSION_CARDGROUP_TITLE = "Learning Curves: Metric progression over trials"
30+
PROGRESSION_CARDGROUP_SUBTITLE = (
31+
"These plots show curve metrics (learning curves) that track the evolution of "
32+
"each metric over the course of the experiment. The plots display how metrics "
33+
"change during trial execution, both by progression (e.g., epochs or steps) "
34+
"and by wallclock time. This is useful for monitoring optimization progress and "
35+
"informing early stopping decisions."
36+
)
37+
2938

3039
@final
3140
class ProgressionPlot(Analysis):

ax/analysis/results.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from ax.analysis.best_trials import BestTrials
1414
from ax.analysis.plotly.arm_effects import ArmEffectsPlot
1515
from ax.analysis.plotly.bandit_rollout import BanditRollout
16+
from ax.analysis.plotly.progression import (
17+
PROGRESSION_CARDGROUP_SUBTITLE,
18+
PROGRESSION_CARDGROUP_TITLE,
19+
ProgressionPlot,
20+
)
1621
from ax.analysis.plotly.scatter import (
1722
SCATTER_CARDGROUP_SUBTITLE,
1823
SCATTER_CARDGROUP_TITLE,
@@ -25,6 +30,8 @@
2530
from ax.core.arm import Arm
2631
from ax.core.batch_trial import BatchTrial
2732
from ax.core.experiment import Experiment
33+
from ax.core.map_data import MapData
34+
from ax.core.map_metric import MapMetric
2835
from ax.core.outcome_constraint import ScalarizedOutcomeConstraint
2936
from ax.core.trial_status import TrialStatus
3037
from ax.core.utils import is_bandit_experiment
@@ -240,6 +247,33 @@ def compute(
240247
adapter=adapter,
241248
)
242249

250+
# Compute progression plots for MapMetrics (learning curves)
251+
progression_group = None
252+
data = experiment.lookup_data()
253+
has_map_data = isinstance(data, MapData)
254+
metrics = experiment.metrics.values()
255+
map_metrics = [m for m in metrics if isinstance(m, MapMetric)]
256+
if has_map_data and len(map_metrics) > 0:
257+
map_metric_names = [m.name for m in map_metrics]
258+
progression_cards = [
259+
ProgressionPlot(
260+
metric_name=metric_name, by_wallclock_time=by_wallclock_time
261+
).compute_or_error_card(
262+
experiment=experiment,
263+
generation_strategy=generation_strategy,
264+
adapter=adapter,
265+
)
266+
for metric_name in map_metric_names
267+
for by_wallclock_time in (False, True)
268+
]
269+
if progression_cards:
270+
progression_group = AnalysisCardGroup(
271+
name="ProgressionAnalysis",
272+
title=PROGRESSION_CARDGROUP_TITLE,
273+
subtitle=PROGRESSION_CARDGROUP_SUBTITLE,
274+
children=progression_cards,
275+
)
276+
243277
return self._create_analysis_card_group(
244278
title=RESULTS_CARDGROUP_TITLE,
245279
subtitle=RESULTS_CARDGROUP_SUBTITLE,
@@ -252,6 +286,7 @@ def compute(
252286
bandit_rollout_card,
253287
best_trials_card,
254288
utility_progression_card,
289+
progression_group,
255290
summary,
256291
)
257292
if child is not None

ax/analysis/tests/test_results.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
get_experiment_with_scalarized_objective_and_outcome_constraint,
3636
get_offline_experiments,
3737
get_online_experiments,
38+
get_test_map_data_experiment,
3839
)
3940
from ax.utils.testing.mock import mock_botorch_optimize
4041
from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node
@@ -499,6 +500,39 @@ def test_offline_experiments(self) -> None:
499500
self.assertIsNotNone(card_group)
500501
self.assertGreater(len(card_group.children), 0)
501502

503+
@mock_botorch_optimize
504+
def test_compute_with_map_data_includes_progression_plots(self) -> None:
505+
# Setup: Create experiment with MapData and MapMetrics
506+
experiment = get_test_map_data_experiment(
507+
num_trials=3, num_fetches=2, num_complete=2
508+
)
509+
generation_strategy = get_default_generation_strategy_at_MBM_node(
510+
experiment=experiment
511+
)
512+
513+
# Execute: Compute ResultsAnalysis
514+
card_group = ResultsAnalysis().compute(
515+
experiment=experiment,
516+
generation_strategy=generation_strategy,
517+
)
518+
519+
# Assert: ProgressionAnalysis group exists with children
520+
progression_group = None
521+
for child in card_group.children:
522+
if child.name == "ProgressionAnalysis":
523+
progression_group = child
524+
break
525+
526+
self.assertIsNotNone(
527+
progression_group,
528+
"ProgressionAnalysis group should be present for MapMetric experiments",
529+
)
530+
self.assertGreater(
531+
len(assert_is_instance(progression_group, AnalysisCardGroup).children),
532+
0,
533+
"ProgressionAnalysis group should have at least one progression plot",
534+
)
535+
502536

503537
class TestArmEffectsPair(TestCase):
504538
@mock_botorch_optimize

0 commit comments

Comments
 (0)