2424from model_analyzer .config .input .config_command_profile import ConfigCommandProfile
2525from model_analyzer .config .input .config_command_report import ConfigCommandReport
2626
27+
2728class ConstraintManager :
2829 """
2930 Handles processing and applying
@@ -34,20 +35,21 @@ class ConstraintManager:
3435 config: ConfigCommandProfile or ConfigCommandReport
3536 """
3637
37- def __init__ (self , config : Union [ConfigCommandProfile , ConfigCommandReport ]) -> None :
38+ def __init__ (
39+ self , config : Union [ConfigCommandProfile ,
40+ ConfigCommandReport ]) -> None :
3841 self ._constraints = {}
3942
4043 if config :
4144 # Model constraints
4245 if "profile_models" in config .get_config ():
4346 for model in config .profile_models :
44- self ._constraints [model .model_name ()
45- ] = model .constraints ()
47+ self ._constraints [model .model_name ()] = model .constraints ()
4648
4749 # Global constraints
4850 if "constraints" in config .get_all_config ():
49- self ._constraints [GLOBAL_CONSTRAINTS_KEY ] = ModelConstraints (config . get_all_config ()[
50- "constraints" ])
51+ self ._constraints [GLOBAL_CONSTRAINTS_KEY ] = ModelConstraints (
52+ config . get_all_config ()[ "constraints" ])
5153
5254 def get_constraints_for_all_models (self ):
5355 """
@@ -59,8 +61,8 @@ def get_constraints_for_all_models(self):
5961
6062 return self ._constraints
6163
62- def satisfies_constraints (self ,
63- run_config_measurement : 'RunConfigMeasurement' ) -> bool :
64+ def satisfies_constraints (
65+ self , run_config_measurement : 'RunConfigMeasurement' ) -> bool :
6466 """
6567 Checks that the measurements, for every model, satisfy
6668 the provided list of constraints
@@ -77,18 +79,20 @@ def satisfies_constraints(self,
7779 """
7880
7981 if self ._constraints :
80- for (model_name , model_metrics ) in run_config_measurement .data ().items ():
82+ for (model_name ,
83+ model_metrics ) in run_config_measurement .data ().items ():
8184 for metric in model_metrics :
8285 if self ._metric_matches_constraint (
8386 metric , self ._constraints [model_name ]):
8487 if self ._get_failure_percentage (
85- metric , self ._constraints [model_name ][metric .tag ]) > 0 :
88+ metric ,
89+ self ._constraints [model_name ][metric .tag ]) > 0 :
8690 return False
8791
8892 return True
8993
90- def constraint_failure_percentage (self ,
91- run_config_measurement : 'RunConfigMeasurement' ) -> float :
94+ def constraint_failure_percentage (
95+ self , run_config_measurement : 'RunConfigMeasurement' ) -> float :
9296 """
9397 Additive percentage, for every measurement, in every model, of how much
9498 the RCM is failing the constraints by
@@ -100,7 +104,8 @@ def constraint_failure_percentage(self,
100104 failure_percentage : float = 0
101105
102106 if self ._constraints :
103- for (model_name , model_metrics ) in run_config_measurement .data ().items ():
107+ for (model_name ,
108+ model_metrics ) in run_config_measurement .data ().items ():
104109 for metric in model_metrics :
105110 if self ._metric_matches_constraint (
106111 metric , self ._constraints [model_name ]):
@@ -109,15 +114,15 @@ def constraint_failure_percentage(self,
109114
110115 return failure_percentage * 100
111116
112- def _metric_matches_constraint (self ,
113- metric : Record , constraint : ModelConstraints ) -> bool :
117+ def _metric_matches_constraint (self , metric : Record ,
118+ constraint : ModelConstraints ) -> bool :
114119 if constraint .has_metric (metric .tag ):
115120 return True
116121 else :
117122 return False
118123
119- def _get_failure_percentage (self , metric : Record , constraint : Dict [ str ,
120- int ]) -> float :
124+ def _get_failure_percentage (self , metric : Record ,
125+ constraint : Dict [ str , int ]) -> float :
121126
122127 failure_percentage = 0
123128
0 commit comments