2020from ax .core .parameter_constraint import ParameterConstraint
2121from ax .core .search_space import SearchSpace
2222from ax .core .types import ComparisonOp
23- from ax .service .orchestrator import OrchestratorOptions
2423from ax .utils .common .testutils import TestCase
25- from ax .utils .testing .core_stubs import get_branin_experiment
24+ from ax .utils .testing .core_stubs import (
25+ get_branin_experiment ,
26+ get_improvement_global_stopping_strategy ,
27+ get_percentile_early_stopping_strategy ,
28+ )
2629
2730
2831class TestComplexityRatingAnalysis (TestCase ):
2932 def setUp (self ) -> None :
3033 super ().setUp ()
3134 self .experiment = get_branin_experiment ()
32- self .options = OrchestratorOptions ()
3335 self .tier_metadata : dict [str , object ] = {
3436 "user_supplied_max_trials" : 100 ,
3537 "uses_standard_api" : True ,
3638 }
3739
3840 def test_validate_applicable_state_requires_experiment (self ) -> None :
39- healthcheck = ComplexityRatingAnalysis (
40- options = self .options , tier_metadata = self .tier_metadata
41- )
41+ healthcheck = ComplexityRatingAnalysis (tier_metadata = self .tier_metadata )
4242 result = healthcheck .validate_applicable_state (experiment = None )
4343 self .assertIsNotNone (result )
4444 self .assertIn ("Experiment is required" , result )
4545
46- def test_validate_applicable_state_requires_options (self ) -> None :
47- healthcheck = ComplexityRatingAnalysis (
48- options = None , tier_metadata = self .tier_metadata
49- )
50- result = healthcheck .validate_applicable_state (experiment = self .experiment )
51- self .assertIsNotNone (result )
52- self .assertIn ("OrchestratorOptions is required" , result )
53-
5446 def test_validate_applicable_state_passes_with_valid_inputs (self ) -> None :
55- healthcheck = ComplexityRatingAnalysis (
56- options = self .options , tier_metadata = self .tier_metadata
57- )
47+ healthcheck = ComplexityRatingAnalysis (tier_metadata = self .tier_metadata )
5848 result = healthcheck .validate_applicable_state (experiment = self .experiment )
5949 self .assertIsNone (result )
6050
6151 def test_standard_configuration (self ) -> None :
62- healthcheck = ComplexityRatingAnalysis (
63- options = self .options , tier_metadata = self .tier_metadata
64- )
52+ healthcheck = ComplexityRatingAnalysis (tier_metadata = self .tier_metadata )
6553 card = healthcheck .compute (experiment = self .experiment )
6654
6755 self .assertEqual (card .name , "ComplexityRatingAnalysis" )
@@ -90,7 +78,7 @@ def test_parameter_counts(self) -> None:
9078 ]
9179 self .experiment ._search_space = SearchSpace (parameters = params )
9280 card = ComplexityRatingAnalysis (
93- options = self . options , tier_metadata = self .tier_metadata
81+ tier_metadata = self .tier_metadata
9482 ).compute (experiment = self .experiment )
9583
9684 self .assertEqual (card .get_status (), expected_status )
@@ -115,7 +103,7 @@ def test_objectives_count(self) -> None:
115103 )
116104 )
117105 card = ComplexityRatingAnalysis (
118- options = self . options , tier_metadata = self .tier_metadata
106+ tier_metadata = self .tier_metadata
119107 ).compute (experiment = self .experiment )
120108
121109 self .assertEqual (card .get_status (), expected_status )
@@ -132,9 +120,9 @@ def test_constraints(self) -> None:
132120 for m in metrics
133121 ],
134122 )
135- card = ComplexityRatingAnalysis (
136- options = self .options , tier_metadata = self . tier_metadata
137- ). compute ( experiment = self . experiment )
123+ card = ComplexityRatingAnalysis (tier_metadata = self . tier_metadata ). compute (
124+ experiment = self .experiment
125+ )
138126
139127 self .assertEqual (card .get_status (), HealthcheckStatus .WARNING )
140128 self .assertIn ("Advanced" , card .subtitle )
@@ -156,35 +144,34 @@ def test_constraints(self) -> None:
156144 self .experiment ._search_space = SearchSpace (
157145 parameters = params , parameter_constraints = parameter_constraints
158146 )
159- card = ComplexityRatingAnalysis (
160- options = self .options , tier_metadata = self . tier_metadata
161- ). compute ( experiment = self . experiment )
147+ card = ComplexityRatingAnalysis (tier_metadata = self . tier_metadata ). compute (
148+ experiment = self .experiment
149+ )
162150
163151 self .assertEqual (card .get_status (), HealthcheckStatus .WARNING )
164152 self .assertIn ("Advanced" , card .subtitle )
165153 self .assertIn ("3 parameter constraints" , card .subtitle )
166154
167155 def test_stopping_strategies (self ) -> None :
168- test_cases = [
169- ("early_stopping" , True , False , "Early stopping" ),
170- ("global_stopping" , False , True , "Global stopping" ),
171- ]
156+ with self .subTest (strategy = "early_stopping" ):
157+ card = ComplexityRatingAnalysis (
158+ tier_metadata = self .tier_metadata ,
159+ early_stopping_strategy = get_percentile_early_stopping_strategy (),
160+ ).compute (experiment = self .experiment )
172161
173- for name , uses_early , uses_global , expected_msg in test_cases :
174- with self .subTest (strategy = name ):
175- options = OrchestratorOptions (
176- # pyre-fixme[6]: Using a mock value for testing
177- early_stopping_strategy = "mock" if uses_early else None ,
178- # pyre-fixme[6]: Using a mock value for testing
179- global_stopping_strategy = "mock" if uses_global else None ,
180- )
181- card = ComplexityRatingAnalysis (
182- options = options , tier_metadata = self .tier_metadata
183- ).compute (experiment = self .experiment )
162+ self .assertEqual (card .get_status (), HealthcheckStatus .WARNING )
163+ self .assertIn ("Advanced" , card .subtitle )
164+ self .assertIn ("Early stopping" , card .subtitle )
184165
185- self .assertEqual (card .get_status (), HealthcheckStatus .WARNING )
186- self .assertIn ("Advanced" , card .subtitle )
187- self .assertIn (expected_msg , card .subtitle )
166+ with self .subTest (strategy = "global_stopping" ):
167+ card = ComplexityRatingAnalysis (
168+ tier_metadata = self .tier_metadata ,
169+ global_stopping_strategy = get_improvement_global_stopping_strategy (),
170+ ).compute (experiment = self .experiment )
171+
172+ self .assertEqual (card .get_status (), HealthcheckStatus .WARNING )
173+ self .assertIn ("Advanced" , card .subtitle )
174+ self .assertIn ("Global stopping" , card .subtitle )
188175
189176 def test_trial_counts (self ) -> None :
190177 test_cases = [
@@ -198,48 +185,47 @@ def test_trial_counts(self) -> None:
198185 "user_supplied_max_trials" : max_trials ,
199186 "uses_standard_api" : True ,
200187 }
201- card = ComplexityRatingAnalysis (
202- options = self .options , tier_metadata = tier_metadata
203- ). compute ( experiment = self . experiment )
188+ card = ComplexityRatingAnalysis (tier_metadata = tier_metadata ). compute (
189+ experiment = self .experiment
190+ )
204191
205192 self .assertEqual (card .get_status (), expected_status )
206193 self .assertIn (expected_tier , card .subtitle )
207194 self .assertIn (expected_msg , card .subtitle )
208195
209196 def test_unsupported_configurations (self ) -> None :
210- test_cases = [
211- (
212- "not_using_standard_api" ,
213- OrchestratorOptions (),
214- {"user_supplied_max_trials" : 100 , "uses_standard_api" : False },
215- "uses_standard_api=False" ,
216- ),
217- (
218- "high_failure_rate" ,
219- OrchestratorOptions (tolerated_trial_failure_rate = 0.95 ),
220- {"user_supplied_max_trials" : 100 , "uses_standard_api" : True },
221- "0.95" ,
222- ),
223- (
224- "invalid_failure_rate_check" ,
225- OrchestratorOptions (
226- max_pending_trials = 10 ,
227- min_failed_trials_for_failure_rate_check = 50 ,
228- ),
229- {"user_supplied_max_trials" : 100 , "uses_standard_api" : True },
230- "min_failed_trials_for_failure_rate_check" ,
231- ),
232- ]
197+ with self .subTest (config = "not_using_standard_api" ):
198+ tier_metadata = {
199+ "user_supplied_max_trials" : 100 ,
200+ "uses_standard_api" : False ,
201+ }
202+ card = ComplexityRatingAnalysis (tier_metadata = tier_metadata ).compute (
203+ experiment = self .experiment
204+ )
205+ self .assertEqual (card .get_status (), HealthcheckStatus .FAIL )
206+ self .assertIn ("Unsupported" , card .subtitle )
207+ self .assertIn ("uses_standard_api=False" , card .subtitle )
233208
234- for name , options , tier_metadata , expected_msg in test_cases :
235- with self .subTest (config = name ):
236- card = ComplexityRatingAnalysis (
237- options = options , tier_metadata = tier_metadata
238- ).compute (experiment = self .experiment )
209+ with self .subTest (config = "high_failure_rate" ):
210+ tier_metadata = {"user_supplied_max_trials" : 100 , "uses_standard_api" : True }
211+ card = ComplexityRatingAnalysis (
212+ tier_metadata = tier_metadata ,
213+ tolerated_trial_failure_rate = 0.95 ,
214+ ).compute (experiment = self .experiment )
215+ self .assertEqual (card .get_status (), HealthcheckStatus .FAIL )
216+ self .assertIn ("Unsupported" , card .subtitle )
217+ self .assertIn ("0.95" , card .subtitle )
239218
240- self .assertEqual (card .get_status (), HealthcheckStatus .FAIL )
241- self .assertIn ("Unsupported" , card .subtitle )
242- self .assertIn (expected_msg , card .subtitle )
219+ with self .subTest (config = "invalid_failure_rate_check" ):
220+ tier_metadata = {"user_supplied_max_trials" : 100 , "uses_standard_api" : True }
221+ card = ComplexityRatingAnalysis (
222+ tier_metadata = tier_metadata ,
223+ max_pending_trials = 10 ,
224+ min_failed_trials_for_failure_rate_check = 50 ,
225+ ).compute (experiment = self .experiment )
226+ self .assertEqual (card .get_status (), HealthcheckStatus .FAIL )
227+ self .assertIn ("Unsupported" , card .subtitle )
228+ self .assertIn ("min_failed_trials_for_failure_rate_check" , card .subtitle )
243229
244230 def test_unordered_choice_parameters (self ) -> None :
245231 params = [
@@ -257,9 +243,9 @@ def test_unordered_choice_parameters(self) -> None:
257243
258244 self .assertTrue (is_unordered_choice (params [1 ], min_choices = 3 , max_choices = 5 ))
259245
260- card = ComplexityRatingAnalysis (
261- options = self .options , tier_metadata = self . tier_metadata
262- ). compute ( experiment = self . experiment )
246+ card = ComplexityRatingAnalysis (tier_metadata = self . tier_metadata ). compute (
247+ experiment = self .experiment
248+ )
263249
264250 self .assertEqual (card .get_status (), HealthcheckStatus .WARNING )
265251 self .assertIn ("Advanced" , card .subtitle )
@@ -279,9 +265,9 @@ def test_binary_parameters_count(self) -> None:
279265 for p in params :
280266 self .assertTrue (can_map_to_binary (p ))
281267
282- card = ComplexityRatingAnalysis (
283- options = self .options , tier_metadata = self . tier_metadata
284- ). compute ( experiment = self . experiment )
268+ card = ComplexityRatingAnalysis (tier_metadata = self . tier_metadata ). compute (
269+ experiment = self .experiment
270+ )
285271
286272 self .assertEqual (card .get_status (), HealthcheckStatus .WARNING )
287273 self .assertIn ("Advanced" , card .subtitle )
@@ -300,10 +286,9 @@ def test_multiple_violations(self) -> None:
300286 experiment = self .experiment
301287 experiment ._search_space = SearchSpace (parameters = params )
302288 tier_metadata = {"user_supplied_max_trials" : 300 , "uses_standard_api" : True }
303- # pyre-ignore[6]: Using a mock value for testing
304- options = OrchestratorOptions (early_stopping_strategy = "mock" )
305289 card = ComplexityRatingAnalysis (
306- options = options , tier_metadata = tier_metadata
290+ tier_metadata = tier_metadata ,
291+ early_stopping_strategy = get_percentile_early_stopping_strategy (),
307292 ).compute (experiment = experiment )
308293
309294 self .assertEqual (card .get_status (), HealthcheckStatus .WARNING )
@@ -313,9 +298,9 @@ def test_multiple_violations(self) -> None:
313298 self .assertIn ("Early stopping is enabled" , card .subtitle )
314299
315300 def test_dataframe_summary (self ) -> None :
316- card = ComplexityRatingAnalysis (
317- options = self .options , tier_metadata = self . tier_metadata
318- ). compute ( experiment = self . experiment )
301+ card = ComplexityRatingAnalysis (tier_metadata = self . tier_metadata ). compute (
302+ experiment = self .experiment
303+ )
319304
320305 df = card .df
321306 self .assertIsNotNone (df )
0 commit comments