Skip to content

Commit 9bdca37

Browse files
authored
Add slow mode to quick search (#510)
* Add slow mode support * next door -> adjacent * remove redundant comment * update some sorting code
1 parent cdb5f77 commit 9bdca37

File tree

4 files changed

+387
-58
lines changed

4 files changed

+387
-58
lines changed

model_analyzer/config/generate/coordinate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from copy import copy
15+
from copy import deepcopy
1616

1717

1818
class Coordinate:
@@ -25,7 +25,7 @@ def __init__(self, val):
2525
val: list
2626
List of floats or integers cooresponding to the location in space
2727
"""
28-
self._values = copy(val)
28+
self._values = deepcopy(val)
2929

3030
def __getitem__(self, idx):
3131
return self._values[idx]

model_analyzer/config/generate/neighborhood.py

Lines changed: 125 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def __init__(self, neighborhood_config: NeighborhoodConfig,
4949
self._radius = self._config.get_radius()
5050
self._neighborhood = self._create_neighborhood()
5151

52+
self._force_slow_mode = False
53+
5254
@property
5355
def coordinate_data(self):
5456
return self._coordinate_data
@@ -71,31 +73,75 @@ def enough_coordinates_initialized(self) -> bool:
7173
"""
7274
Returns true if enough coordinates inside of the neighborhood
7375
have been initialized. Else false
76+
77+
If the neighborhood is in slow mode, this means all adjacent neighbors
78+
must be visited
79+
"""
80+
if self._is_slow_mode():
81+
return self._are_all_adjacent_neighbors_visited()
82+
else:
83+
min_initialized = self._config.get_min_initialized()
84+
num_initialized = len(self._get_initialized_coordinates())
85+
return num_initialized >= min_initialized
86+
87+
def force_slow_mode(self):
7488
"""
75-
min_initialized = self._config.get_min_initialized()
76-
num_initialized = len(self._get_initialized_coordinates())
77-
return num_initialized >= min_initialized
89+
When called, forces the neighborhood into slow mode
90+
"""
91+
self._force_slow_mode = True
7892

7993
def calculate_new_coordinate(self,
8094
magnitude: int,
8195
enable_clipping: bool = True,
8296
clip_value: int = 2) -> Coordinate:
8397
"""
8498
Based on the measurements in the neighborhood, determine where
85-
the next location should be
99+
the next location should be.
100+
101+
If the neighborhood is in slow mode, return the best found measurement
102+
Otherwise calculate a new coordinate from the measurements
86103
87104
Parameters
88105
----------
89106
magnitude
90107
How large of a step to take
91-
disable_clipping
108+
enable_clipping
92109
Determines whether or not to clip the final step vector.
110+
clip_value
111+
What value to clip the vector at, if it is enabled
93112
94113
Returns
95114
-------
96115
new_coordinate
97116
The new coordinate computed based on the neighborhood measurements.
98117
"""
118+
119+
if self._is_slow_mode():
120+
return self._get_best_coordinate_found()
121+
else:
122+
return self._calculate_new_coordinate(magnitude, enable_clipping,
123+
clip_value)
124+
125+
def _get_best_coordinate_found(self) -> Coordinate:
126+
vectors, measurements = self._get_measurements_passing_constraints()
127+
128+
if len(vectors) == 0:
129+
return self._home_coordinate
130+
131+
home_measurement = self._get_home_measurement()
132+
133+
if home_measurement.is_passing_constraints():
134+
vectors.append(Coordinate([0] * self._config.get_num_dimensions()))
135+
measurements.append(home_measurement)
136+
137+
_, best_vector = sorted(zip(measurements, vectors))[-1]
138+
139+
best_coordinate = self._home_coordinate + best_vector
140+
return best_coordinate
141+
142+
def _calculate_new_coordinate(self, magnitude, enable_clipping,
143+
clip_value) -> Coordinate:
144+
99145
step_vector = self._get_step_vector() * magnitude
100146

101147
if enable_clipping:
@@ -133,14 +179,30 @@ def _clip_vector_values(self, vector: Coordinate,
133179

134180
if max_value > clip_value and max_value != 0:
135181
for i in range(len(vector)):
136-
vector[i] = clip_value * vector[i]/max_value
182+
vector[i] = clip_value * vector[i] / max_value
137183
return vector
138184

139185
def pick_coordinate_to_initialize(self) -> Optional[Coordinate]:
140186
"""
141187
Based on the initialized coordinate values, pick an unvisited
142188
coordinate to initialize next.
189+
190+
If the neighborhood is in slow mode, only pick from within the adjacent neighbors
143191
"""
192+
193+
if self._is_slow_mode():
194+
return self._pick_slow_mode_coordinate_to_initialize()
195+
else:
196+
return self._pick_fast_mode_coordinate_to_initialize()
197+
198+
def _pick_slow_mode_coordinate_to_initialize(self):
199+
for neighbor in self._get_all_adjacent_neighbors():
200+
if not self._is_coordinate_visited(neighbor):
201+
return neighbor
202+
203+
raise Exception("Picking slow mode coordinate, but none are unvisited")
204+
205+
def _pick_fast_mode_coordinate_to_initialize(self):
144206
covered_values_per_dimension = self._get_covered_values_per_dimension()
145207

146208
max_num_uncovered = -1
@@ -250,12 +312,11 @@ def _get_step_vector(self) -> Coordinate:
250312
step_vector
251313
a coordinate that tells the direction to move.
252314
"""
253-
home_measurement = self._coordinate_data.get_measurement(
254-
coordinate=self._home_coordinate)
315+
home_measurement = self._get_home_measurement()
255316

256317
assert home_measurement is not None, "Home measurement cannot be NoneType."
257318

258-
if home_measurement.is_passing_constraints():
319+
if self._is_home_passing_constraints():
259320
return self._optimize_for_objectives(home_measurement)
260321

261322
return self._optimize_for_constraints(home_measurement)
@@ -428,3 +489,58 @@ def _get_num_uncovered_values(
428489
num_uncovered += 1
429490

430491
return num_uncovered
492+
493+
def _is_slow_mode(self):
494+
if self._force_slow_mode:
495+
return True
496+
497+
if not self._is_home_measured():
498+
return False
499+
500+
passing_vectors, _ = self._get_measurements_passing_constraints()
501+
all_vectors, _ = self._get_all_visited_measurements()
502+
503+
any_failing = len(all_vectors) != len(passing_vectors)
504+
any_passing = len(passing_vectors) != 0
505+
home_passing = self._is_home_passing_constraints()
506+
507+
return (home_passing and any_failing) or (not home_passing and
508+
any_passing)
509+
510+
def _are_all_adjacent_neighbors_visited(self):
511+
for neighbor in self._get_all_adjacent_neighbors():
512+
if not self._is_coordinate_visited(neighbor):
513+
return False
514+
return True
515+
516+
def _get_all_adjacent_neighbors(self):
517+
adjacent_neighbors = []
518+
519+
for dim in range(self._config.get_num_dimensions()):
520+
dimension = self._config.get_dimension(dim)
521+
522+
down_neighbor = Coordinate(self._home_coordinate)
523+
down_neighbor[dim] -= 1
524+
if down_neighbor[dim] >= dimension.get_min_idx():
525+
adjacent_neighbors.append(down_neighbor)
526+
527+
up_neighbor = Coordinate(self._home_coordinate)
528+
up_neighbor[dim] += 1
529+
if up_neighbor[dim] <= dimension.get_max_idx():
530+
adjacent_neighbors.append(up_neighbor)
531+
532+
return adjacent_neighbors
533+
534+
def _get_home_measurement(self):
535+
return self._coordinate_data.get_measurement(
536+
coordinate=self._home_coordinate)
537+
538+
def _is_home_measured(self):
539+
return self._get_home_measurement() is not None
540+
541+
def _is_home_passing_constraints(self):
542+
if not self._is_home_measured():
543+
raise Exception("Can't check home passing if it isn't measured yet")
544+
545+
home_measurement = self._get_home_measurement()
546+
return home_measurement.is_passing_constraints()

model_analyzer/config/generate/quick_run_config_generator.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def __init__(self, search_config: SearchConfig,
9595
self._search_config.get_neighborhood_config(),
9696
self._home_coordinate)
9797

98+
# Sticky bit. Once true, we should never stay at a home that is failing or None
99+
self._home_has_passed = False
100+
98101
self._done = False
99102

100103
def _is_done(self) -> bool:
@@ -121,8 +124,7 @@ def _step(self):
121124
Determine self._coordinate_to_measure, which is what is used to
122125
create the next RunConfig
123126
"""
124-
if self._measuring_home_coordinate(
125-
) and self._get_last_results() is None:
127+
if self._should_step_back():
126128
self._take_step_back()
127129
elif self._neighborhood.enough_coordinates_initialized():
128130
self._take_step()
@@ -153,6 +155,10 @@ def set_last_results(self, measurements: List[Union[RunConfigMeasurement,
153155
if measurements[0] is not None:
154156
self._update_best_measurement(measurement=measurements[0])
155157

158+
if self._measuring_home_coordinate(
159+
) and measurements[0].is_passing_constraints():
160+
self._home_has_passed = True
161+
156162
self._print_debug_logs(measurements)
157163

158164
def _update_best_measurement(self, measurement: RunConfigMeasurement):
@@ -197,20 +203,38 @@ def _take_step(self):
197203
logger.debug(f"Stepping {self._home_coordinate}->{new_coordinate}")
198204
self._home_coordinate = new_coordinate
199205
self._coordinate_to_measure = new_coordinate
200-
self._recreate_neighborhood()
206+
self._recreate_neighborhood(force_slow_mode=False)
201207

202208
def _take_step_back(self):
203209
new_coordinate = self._neighborhood.get_nearest_neighbor(
204210
coordinate_in=self._best_coordinate)
205211

212+
# TODO: TMA-871: handle back-off (and its termination) better.
213+
if new_coordinate == self._home_coordinate:
214+
self._done = True
215+
206216
logger.debug(
207217
f"Stepping back: {self._home_coordinate}->{new_coordinate}")
208218
self._home_coordinate = new_coordinate
209219
self._coordinate_to_measure = new_coordinate
210-
self._recreate_neighborhood()
220+
self._recreate_neighborhood(force_slow_mode=True)
211221

212222
self._magnitude_scaler *= MAGNITUDE_DECAY_RATE
213223

224+
def _should_step_back(self):
225+
"""
226+
Step back if take any of the following steps:
227+
- Step from a passing home to a failing home
228+
- Step from any home to home with a None measurement
229+
"""
230+
if self._measuring_home_coordinate():
231+
if self._get_last_results() is None:
232+
return True
233+
last_results_passed = self._get_last_results(
234+
).is_passing_constraints()
235+
if not last_results_passed and self._home_has_passed:
236+
return True
237+
214238
def _measuring_home_coordinate(self):
215239
return self._coordinate_to_measure == self._home_coordinate
216240

@@ -224,11 +248,13 @@ def _determine_if_done(self, new_coordinate: Coordinate):
224248
if self._coordinate_data.get_visit_count(new_coordinate) >= 2:
225249
self._done = True
226250

227-
def _recreate_neighborhood(self):
251+
def _recreate_neighborhood(self, force_slow_mode: bool):
228252
neighborhood_config = self._search_config.get_neighborhood_config()
229253

230254
self._neighborhood = Neighborhood(neighborhood_config,
231255
self._home_coordinate)
256+
if force_slow_mode:
257+
self._neighborhood.force_slow_mode()
232258

233259
def _pick_coordinate_to_initialize(self):
234260
self._coordinate_to_measure = self._neighborhood.pick_coordinate_to_initialize(

0 commit comments

Comments
 (0)