Skip to content

Commit 7554ae4

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Update Complexity Rating Healthcheck (facebook#4707)
Summary: This diff refactors ComplexityRatingAnalysis to decouple it from OrchestratorOptions by accepting individual configuration parameters directly instead of the entire options object. Differential Revision: D89778632 Privacy Context Container: L1307644
1 parent 95b7f94 commit 7554ae4

File tree

2 files changed

+106
-111
lines changed

2 files changed

+106
-111
lines changed

ax/analysis/healthcheck/complexity_rating.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
HealthcheckStatus,
1717
)
1818
from ax.core.experiment import Experiment
19+
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
1920
from ax.generation_strategy.generation_strategy import GenerationStrategy
20-
from ax.service.orchestrator import OrchestratorOptions
21+
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
2122
from ax.utils.common.complexity_utils import (
2223
check_if_in_standard,
2324
DEFAULT_TIER_MESSAGES,
@@ -61,16 +62,17 @@ class ComplexityRatingAnalysis(Analysis):
6162

6263
def __init__(
6364
self,
64-
options: OrchestratorOptions | None = None,
6565
tier_metadata: dict[str, Any] | None = None,
6666
tier_messages: TierMessages = DEFAULT_TIER_MESSAGES,
67+
early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
68+
global_stopping_strategy: BaseGlobalStoppingStrategy | None = None,
69+
tolerated_trial_failure_rate: float | None = None,
70+
max_pending_trials: int | None = None,
71+
min_failed_trials_for_failure_rate_check: int | None = None,
6772
) -> None:
6873
"""Initialize the ComplexityRatingAnalysis.
6974
7075
Args:
71-
options: The orchestrator options used for the optimization.
72-
Required to evaluate early stopping, global stopping, and
73-
failure rate settings.
7476
tier_metadata: Additional tier-related metadata from the orchestrator.
7577
Supported keys:
7678
- 'user_supplied_max_trials': Maximum number of trials.
@@ -82,12 +84,27 @@ def __init__(
8284
generic messages suitable for most users. Pass a custom TierMessages
8385
instance to provide tool-specific descriptions, support SLAs,
8486
links to docs, or contact information.
87+
early_stopping_strategy: The early stopping strategy, if any. Used to
88+
determine if early stopping is enabled. Defaults to None.
89+
global_stopping_strategy: The global stopping strategy, if any. Used to
90+
determine if global stopping is enabled. Defaults to None.
91+
tolerated_trial_failure_rate: Fraction of trials allowed to fail without
92+
the whole optimization ending. Default value used is 0.5.
93+
max_pending_trials: Maximum number of pending trials. Default used is 10.
94+
min_failed_trials_for_failure_rate_check: Minimum failed trials before
95+
failure rate is checked. Default value used is 5.
8596
"""
86-
self.options = options
8797
self.tier_metadata: dict[str, Any] = (
8898
tier_metadata if tier_metadata is not None else {}
8999
)
90100
self.tier_messages = tier_messages
101+
self.early_stopping_strategy = early_stopping_strategy
102+
self.global_stopping_strategy = global_stopping_strategy
103+
self.tolerated_trial_failure_rate = tolerated_trial_failure_rate
104+
self.max_pending_trials = max_pending_trials
105+
self.min_failed_trials_for_failure_rate_check = (
106+
min_failed_trials_for_failure_rate_check
107+
)
91108

92109
@override
93110
def validate_applicable_state(
@@ -98,11 +115,6 @@ def validate_applicable_state(
98115
) -> str | None:
99116
if experiment is None:
100117
return "Experiment is required for ComplexityRatingAnalysis."
101-
if self.options is None:
102-
return (
103-
"OrchestratorOptions is required for ComplexityRatingAnalysis. "
104-
"Please pass options to the constructor."
105-
)
106118
return None
107119

108120
@override
@@ -120,8 +132,7 @@ def compute(
120132
121133
Note:
122134
This method assumes ``validate_applicable_state`` has been called
123-
and returned None, ensuring ``experiment`` and ``self.options``
124-
are not None.
135+
and returned None, ensuring ``experiment`` is not None.
125136
126137
Args:
127138
experiment: The Ax Experiment to analyze. Must not be None.
@@ -135,16 +146,15 @@ def compute(
135146
with key experiment metrics.
136147
"""
137148
experiment = none_throws(experiment)
138-
options = none_throws(self.options)
139149
optimization_summary = summarize_ax_optimization_complexity(
140150
experiment=experiment,
141151
tier_metadata=self.tier_metadata,
142-
early_stopping_strategy=options.early_stopping_strategy,
143-
global_stopping_strategy=options.global_stopping_strategy,
144-
tolerated_trial_failure_rate=options.tolerated_trial_failure_rate,
145-
max_pending_trials=options.max_pending_trials,
152+
early_stopping_strategy=self.early_stopping_strategy,
153+
global_stopping_strategy=self.global_stopping_strategy,
154+
tolerated_trial_failure_rate=self.tolerated_trial_failure_rate,
155+
max_pending_trials=self.max_pending_trials,
146156
min_failed_trials_for_failure_rate_check=(
147-
options.min_failed_trials_for_failure_rate_check
157+
self.min_failed_trials_for_failure_rate_check
148158
),
149159
)
150160

ax/analysis/healthcheck/tests/test_complexity_rating.py

Lines changed: 77 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -20,48 +20,36 @@
2020
from ax.core.parameter_constraint import ParameterConstraint
2121
from ax.core.search_space import SearchSpace
2222
from ax.core.types import ComparisonOp
23-
from ax.service.orchestrator import OrchestratorOptions
2423
from ax.utils.common.testutils import TestCase
25-
from ax.utils.testing.core_stubs import get_branin_experiment
24+
from ax.utils.testing.core_stubs import (
25+
get_branin_experiment,
26+
get_improvement_global_stopping_strategy,
27+
get_percentile_early_stopping_strategy,
28+
)
2629

2730

2831
class TestComplexityRatingAnalysis(TestCase):
2932
def setUp(self) -> None:
3033
super().setUp()
3134
self.experiment = get_branin_experiment()
32-
self.options = OrchestratorOptions()
3335
self.tier_metadata: dict[str, object] = {
3436
"user_supplied_max_trials": 100,
3537
"uses_standard_api": True,
3638
}
3739

3840
def test_validate_applicable_state_requires_experiment(self) -> None:
39-
healthcheck = ComplexityRatingAnalysis(
40-
options=self.options, tier_metadata=self.tier_metadata
41-
)
41+
healthcheck = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata)
4242
result = healthcheck.validate_applicable_state(experiment=None)
4343
self.assertIsNotNone(result)
4444
self.assertIn("Experiment is required", result)
4545

46-
def test_validate_applicable_state_requires_options(self) -> None:
47-
healthcheck = ComplexityRatingAnalysis(
48-
options=None, tier_metadata=self.tier_metadata
49-
)
50-
result = healthcheck.validate_applicable_state(experiment=self.experiment)
51-
self.assertIsNotNone(result)
52-
self.assertIn("OrchestratorOptions is required", result)
53-
5446
def test_validate_applicable_state_passes_with_valid_inputs(self) -> None:
55-
healthcheck = ComplexityRatingAnalysis(
56-
options=self.options, tier_metadata=self.tier_metadata
57-
)
47+
healthcheck = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata)
5848
result = healthcheck.validate_applicable_state(experiment=self.experiment)
5949
self.assertIsNone(result)
6050

6151
def test_standard_configuration(self) -> None:
62-
healthcheck = ComplexityRatingAnalysis(
63-
options=self.options, tier_metadata=self.tier_metadata
64-
)
52+
healthcheck = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata)
6553
card = healthcheck.compute(experiment=self.experiment)
6654

6755
self.assertEqual(card.name, "ComplexityRatingAnalysis")
@@ -90,7 +78,7 @@ def test_parameter_counts(self) -> None:
9078
]
9179
self.experiment._search_space = SearchSpace(parameters=params)
9280
card = ComplexityRatingAnalysis(
93-
options=self.options, tier_metadata=self.tier_metadata
81+
tier_metadata=self.tier_metadata
9482
).compute(experiment=self.experiment)
9583

9684
self.assertEqual(card.get_status(), expected_status)
@@ -115,7 +103,7 @@ def test_objectives_count(self) -> None:
115103
)
116104
)
117105
card = ComplexityRatingAnalysis(
118-
options=self.options, tier_metadata=self.tier_metadata
106+
tier_metadata=self.tier_metadata
119107
).compute(experiment=self.experiment)
120108

121109
self.assertEqual(card.get_status(), expected_status)
@@ -132,9 +120,9 @@ def test_constraints(self) -> None:
132120
for m in metrics
133121
],
134122
)
135-
card = ComplexityRatingAnalysis(
136-
options=self.options, tier_metadata=self.tier_metadata
137-
).compute(experiment=self.experiment)
123+
card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute(
124+
experiment=self.experiment
125+
)
138126

139127
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
140128
self.assertIn("Advanced", card.subtitle)
@@ -156,35 +144,34 @@ def test_constraints(self) -> None:
156144
self.experiment._search_space = SearchSpace(
157145
parameters=params, parameter_constraints=parameter_constraints
158146
)
159-
card = ComplexityRatingAnalysis(
160-
options=self.options, tier_metadata=self.tier_metadata
161-
).compute(experiment=self.experiment)
147+
card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute(
148+
experiment=self.experiment
149+
)
162150

163151
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
164152
self.assertIn("Advanced", card.subtitle)
165153
self.assertIn("3 parameter constraints", card.subtitle)
166154

167155
def test_stopping_strategies(self) -> None:
168-
test_cases = [
169-
("early_stopping", True, False, "Early stopping"),
170-
("global_stopping", False, True, "Global stopping"),
171-
]
156+
with self.subTest(strategy="early_stopping"):
157+
card = ComplexityRatingAnalysis(
158+
tier_metadata=self.tier_metadata,
159+
early_stopping_strategy=get_percentile_early_stopping_strategy(),
160+
).compute(experiment=self.experiment)
172161

173-
for name, uses_early, uses_global, expected_msg in test_cases:
174-
with self.subTest(strategy=name):
175-
options = OrchestratorOptions(
176-
# pyre-fixme[6]: Using a mock value for testing
177-
early_stopping_strategy="mock" if uses_early else None,
178-
# pyre-fixme[6]: Using a mock value for testing
179-
global_stopping_strategy="mock" if uses_global else None,
180-
)
181-
card = ComplexityRatingAnalysis(
182-
options=options, tier_metadata=self.tier_metadata
183-
).compute(experiment=self.experiment)
162+
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
163+
self.assertIn("Advanced", card.subtitle)
164+
self.assertIn("Early stopping", card.subtitle)
184165

185-
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
186-
self.assertIn("Advanced", card.subtitle)
187-
self.assertIn(expected_msg, card.subtitle)
166+
with self.subTest(strategy="global_stopping"):
167+
card = ComplexityRatingAnalysis(
168+
tier_metadata=self.tier_metadata,
169+
global_stopping_strategy=get_improvement_global_stopping_strategy(),
170+
).compute(experiment=self.experiment)
171+
172+
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
173+
self.assertIn("Advanced", card.subtitle)
174+
self.assertIn("Global stopping", card.subtitle)
188175

189176
def test_trial_counts(self) -> None:
190177
test_cases = [
@@ -198,48 +185,47 @@ def test_trial_counts(self) -> None:
198185
"user_supplied_max_trials": max_trials,
199186
"uses_standard_api": True,
200187
}
201-
card = ComplexityRatingAnalysis(
202-
options=self.options, tier_metadata=tier_metadata
203-
).compute(experiment=self.experiment)
188+
card = ComplexityRatingAnalysis(tier_metadata=tier_metadata).compute(
189+
experiment=self.experiment
190+
)
204191

205192
self.assertEqual(card.get_status(), expected_status)
206193
self.assertIn(expected_tier, card.subtitle)
207194
self.assertIn(expected_msg, card.subtitle)
208195

209196
def test_unsupported_configurations(self) -> None:
210-
test_cases = [
211-
(
212-
"not_using_standard_api",
213-
OrchestratorOptions(),
214-
{"user_supplied_max_trials": 100, "uses_standard_api": False},
215-
"uses_standard_api=False",
216-
),
217-
(
218-
"high_failure_rate",
219-
OrchestratorOptions(tolerated_trial_failure_rate=0.95),
220-
{"user_supplied_max_trials": 100, "uses_standard_api": True},
221-
"0.95",
222-
),
223-
(
224-
"invalid_failure_rate_check",
225-
OrchestratorOptions(
226-
max_pending_trials=10,
227-
min_failed_trials_for_failure_rate_check=50,
228-
),
229-
{"user_supplied_max_trials": 100, "uses_standard_api": True},
230-
"min_failed_trials_for_failure_rate_check",
231-
),
232-
]
197+
with self.subTest(config="not_using_standard_api"):
198+
tier_metadata = {
199+
"user_supplied_max_trials": 100,
200+
"uses_standard_api": False,
201+
}
202+
card = ComplexityRatingAnalysis(tier_metadata=tier_metadata).compute(
203+
experiment=self.experiment
204+
)
205+
self.assertEqual(card.get_status(), HealthcheckStatus.FAIL)
206+
self.assertIn("Unsupported", card.subtitle)
207+
self.assertIn("uses_standard_api=False", card.subtitle)
233208

234-
for name, options, tier_metadata, expected_msg in test_cases:
235-
with self.subTest(config=name):
236-
card = ComplexityRatingAnalysis(
237-
options=options, tier_metadata=tier_metadata
238-
).compute(experiment=self.experiment)
209+
with self.subTest(config="high_failure_rate"):
210+
tier_metadata = {"user_supplied_max_trials": 100, "uses_standard_api": True}
211+
card = ComplexityRatingAnalysis(
212+
tier_metadata=tier_metadata,
213+
tolerated_trial_failure_rate=0.95,
214+
).compute(experiment=self.experiment)
215+
self.assertEqual(card.get_status(), HealthcheckStatus.FAIL)
216+
self.assertIn("Unsupported", card.subtitle)
217+
self.assertIn("0.95", card.subtitle)
239218

240-
self.assertEqual(card.get_status(), HealthcheckStatus.FAIL)
241-
self.assertIn("Unsupported", card.subtitle)
242-
self.assertIn(expected_msg, card.subtitle)
219+
with self.subTest(config="invalid_failure_rate_check"):
220+
tier_metadata = {"user_supplied_max_trials": 100, "uses_standard_api": True}
221+
card = ComplexityRatingAnalysis(
222+
tier_metadata=tier_metadata,
223+
max_pending_trials=10,
224+
min_failed_trials_for_failure_rate_check=50,
225+
).compute(experiment=self.experiment)
226+
self.assertEqual(card.get_status(), HealthcheckStatus.FAIL)
227+
self.assertIn("Unsupported", card.subtitle)
228+
self.assertIn("min_failed_trials_for_failure_rate_check", card.subtitle)
243229

244230
def test_unordered_choice_parameters(self) -> None:
245231
params = [
@@ -257,9 +243,9 @@ def test_unordered_choice_parameters(self) -> None:
257243

258244
self.assertTrue(is_unordered_choice(params[1], min_choices=3, max_choices=5))
259245

260-
card = ComplexityRatingAnalysis(
261-
options=self.options, tier_metadata=self.tier_metadata
262-
).compute(experiment=self.experiment)
246+
card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute(
247+
experiment=self.experiment
248+
)
263249

264250
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
265251
self.assertIn("Advanced", card.subtitle)
@@ -279,9 +265,9 @@ def test_binary_parameters_count(self) -> None:
279265
for p in params:
280266
self.assertTrue(can_map_to_binary(p))
281267

282-
card = ComplexityRatingAnalysis(
283-
options=self.options, tier_metadata=self.tier_metadata
284-
).compute(experiment=self.experiment)
268+
card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute(
269+
experiment=self.experiment
270+
)
285271

286272
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
287273
self.assertIn("Advanced", card.subtitle)
@@ -300,10 +286,9 @@ def test_multiple_violations(self) -> None:
300286
experiment = self.experiment
301287
experiment._search_space = SearchSpace(parameters=params)
302288
tier_metadata = {"user_supplied_max_trials": 300, "uses_standard_api": True}
303-
# pyre-ignore[6]: Using a mock value for testing
304-
options = OrchestratorOptions(early_stopping_strategy="mock")
305289
card = ComplexityRatingAnalysis(
306-
options=options, tier_metadata=tier_metadata
290+
tier_metadata=tier_metadata,
291+
early_stopping_strategy=get_percentile_early_stopping_strategy(),
307292
).compute(experiment=experiment)
308293

309294
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
@@ -313,9 +298,9 @@ def test_multiple_violations(self) -> None:
313298
self.assertIn("Early stopping is enabled", card.subtitle)
314299

315300
def test_dataframe_summary(self) -> None:
316-
card = ComplexityRatingAnalysis(
317-
options=self.options, tier_metadata=self.tier_metadata
318-
).compute(experiment=self.experiment)
301+
card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute(
302+
experiment=self.experiment
303+
)
319304

320305
df = card.df
321306
self.assertIsNotNone(df)

0 commit comments

Comments
 (0)