Skip to content

Commit 30cbbfe

Browse files
authored
only use one coordinate data (#512)
* only use one coordinate data * improve/fix tests * minor tweak
1 parent 5626710 commit 30cbbfe

File tree

5 files changed

+201
-209
lines changed

5 files changed

+201
-209
lines changed

model_analyzer/config/generate/coordinate_data.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class CoordinateData:
2626
def __init__(self):
2727
self._measurements = {}
2828
self._visit_counts = {}
29+
self._is_measured = {}
2930

3031
def get_measurement(
3132
self, coordinate: Coordinate) -> Optional[RunConfigMeasurement]:
@@ -42,6 +43,20 @@ def set_measurement(self, coordinate: Coordinate,
4243
"""
4344
key: Tuple[Coordinate, ...] = tuple(coordinate)
4445
self._measurements[key] = measurement
46+
self._is_measured[key] = True
47+
48+
def is_measured(self, coordinate: Coordinate) -> bool:
49+
"""
50+
Returns true if a measurement has been set for the given Coordinate
51+
"""
52+
key: Tuple[Coordinate, ...] = tuple(coordinate)
53+
return self._is_measured.get(key, False)
54+
55+
def has_valid_measurement(self, coordinate: Coordinate) -> bool:
56+
"""
57+
Returns true if there is a valid measurement for the given Coordinate
58+
"""
59+
return self.get_measurement(coordinate) is not None
4560

4661
def reset_measurements(self):
4762
"""

model_analyzer/config/generate/neighborhood.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Neighborhood:
4242
TRANSLATION_LIST = [0.09, 0.3, 1.0]
4343

4444
def __init__(self, neighborhood_config: NeighborhoodConfig,
45-
home_coordinate: Coordinate):
45+
home_coordinate: Coordinate, coordinate_data: CoordinateData):
4646
"""
4747
Parameters
4848
----------
@@ -55,17 +55,13 @@ def __init__(self, neighborhood_config: NeighborhoodConfig,
5555

5656
self._config = neighborhood_config
5757
self._home_coordinate = home_coordinate
58-
self._coordinate_data = CoordinateData()
58+
self._coordinate_data = coordinate_data
5959

6060
self._radius = self._config.get_radius()
6161
self._neighborhood = self._create_neighborhood()
6262

6363
self._force_slow_mode = False
6464

65-
@property
66-
def coordinate_data(self):
67-
return self._coordinate_data
68-
6965
@classmethod
7066
def calc_distance(cls, coordinate1: Coordinate,
7167
coordinate2: Coordinate) -> float:
@@ -83,16 +79,17 @@ def calc_distance(cls, coordinate1: Coordinate,
8379
def enough_coordinates_initialized(self) -> bool:
8480
"""
8581
Returns true if enough coordinates inside of the neighborhood
86-
have been initialized. Else false
82+
have been initialized with valid measurements. Else false
8783
8884
If the neighborhood is in slow mode, this means all adjacent neighbors
8985
must be visited
9086
"""
9187
if self._is_slow_mode():
92-
return self._are_all_adjacent_neighbors_visited()
88+
return self._are_all_adjacent_neighbors_measured()
9389
else:
9490
min_initialized = self._config.get_min_initialized()
95-
num_initialized = len(self._get_initialized_coordinates())
91+
num_initialized = len(
92+
self._get_coordinates_with_valid_measurements())
9693
return num_initialized >= min_initialized
9794

9895
def force_slow_mode(self):
@@ -178,7 +175,7 @@ def pick_coordinate_to_initialize(self) -> Optional[Coordinate]:
178175

179176
def _pick_slow_mode_coordinate_to_initialize(self):
180177
for neighbor in self._get_all_adjacent_neighbors():
181-
if not self._is_coordinate_visited(neighbor):
178+
if not self._is_coordinate_measured(neighbor):
182179
return neighbor
183180

184181
raise Exception("Picking slow mode coordinate, but none are unvisited")
@@ -189,7 +186,7 @@ def _pick_fast_mode_coordinate_to_initialize(self):
189186
max_num_uncovered = -1
190187
best_coordinate = None
191188
for coordinate in self._neighborhood:
192-
if not self._is_coordinate_visited(coordinate):
189+
if not self._is_coordinate_measured(coordinate):
193190
num_uncovered = self._get_num_uncovered_values(
194191
coordinate, covered_values_per_dimension)
195192

@@ -257,23 +254,11 @@ def _enumerate_all_values_in_bounds(
257254
tuples = list(product(*possible_index_values))
258255
return [list(x) for x in tuples]
259256

260-
def _get_visited_coordinates(self) -> List[Coordinate]:
261-
"""
262-
Returns the list of coordinates in the neighborhood that have been
263-
visited (except the home coordinate).
264-
"""
265-
visited_coordinates = []
266-
for coordinate in self._neighborhood:
267-
if coordinate != self._home_coordinate \
268-
and self._is_coordinate_visited(coordinate):
269-
visited_coordinates.append(deepcopy(coordinate))
270-
return visited_coordinates
271-
272-
def _get_initialized_coordinates(self) -> List[Coordinate]:
257+
def _get_coordinates_with_valid_measurements(self) -> List[Coordinate]:
273258
initialized_coordinates = []
274259
for coordinate in self._neighborhood:
275-
if coordinate != self._home_coordinate \
276-
and self._is_coordinate_initialized(coordinate):
260+
if coordinate != self._home_coordinate and self._coordinate_data.has_valid_measurement(
261+
coordinate):
277262
initialized_coordinates.append(deepcopy(coordinate))
278263
return initialized_coordinates
279264

@@ -296,7 +281,7 @@ def _calculate_step_vector_from_measurements(
296281
self, compare_constraints: bool) -> Coordinate:
297282

298283
home_measurement = self._get_home_measurement()
299-
vectors, measurements = self._get_all_visited_measurements()
284+
vectors, measurements = self._get_all_measurements()
300285

301286
# This function should only ever be called if all are passing or none are passing
302287
_, p = self._get_measurements_passing_constraints()
@@ -340,7 +325,7 @@ def _calculate_step_vector_from_vectors_and_weights(self, vectors, weights):
340325

341326
return step_vector
342327

343-
def _get_all_visited_measurements(
328+
def _get_all_measurements(
344329
self) -> Tuple[List[Coordinate], List[RunConfigMeasurement]]:
345330
"""
346331
Gather all the visited vectors (directions from the home coordinate)
@@ -351,10 +336,11 @@ def _get_all_visited_measurements(
351336
(vectors, measurements)
352337
collection of vectors and their measurements.
353338
"""
354-
visited_coordinates = self._get_visited_coordinates()
339+
coordinates = self._get_coordinates_with_valid_measurements()
340+
355341
vectors = []
356342
measurements = []
357-
for coordinate in visited_coordinates:
343+
for coordinate in coordinates:
358344
measurement = self._coordinate_data.get_measurement(coordinate)
359345
if measurement:
360346
vectors.append(coordinate - self._home_coordinate)
@@ -372,22 +358,19 @@ def _get_measurements_passing_constraints(
372358
(vectors, measurements)
373359
collection of vectors and their measurements.
374360
"""
375-
visited_coordinates = self._get_visited_coordinates()
361+
coordinates = self._get_coordinates_with_valid_measurements()
376362

377363
vectors = []
378364
measurements = []
379-
for coordinate in visited_coordinates:
365+
for coordinate in coordinates:
380366
measurement = self._coordinate_data.get_measurement(coordinate)
381367
if measurement and measurement.is_passing_constraints():
382368
vectors.append(coordinate - self._home_coordinate)
383369
measurements.append(measurement)
384370
return vectors, measurements
385371

386-
def _is_coordinate_visited(self, coordinate: Coordinate) -> bool:
387-
return self._coordinate_data.get_visit_count(coordinate) > 0
388-
389-
def _is_coordinate_initialized(self, coordinate: Coordinate) -> bool:
390-
return self._coordinate_data.get_measurement(coordinate) is not None
372+
def _is_coordinate_measured(self, coordinate: Coordinate) -> bool:
373+
return self._coordinate_data.is_measured(coordinate)
391374

392375
def _clamp_coordinate_to_bounds(self, coordinate: Coordinate) -> Coordinate:
393376

@@ -409,13 +392,13 @@ def _get_covered_values_per_dimension(self) -> List[Dict[Coordinate, bool]]:
409392
(e.g.)
410393
covered_values_per_dimension[dimension][value] = bool
411394
"""
412-
visited_coordinates = self._get_visited_coordinates()
395+
measured_coordinates = self._get_coordinates_with_valid_measurements()
413396

414397
covered_values_per_dimension: List[Dict[Coordinate, bool]] = [
415398
{} for _ in range(self._config.get_num_dimensions())
416399
]
417400

418-
for coordinate in visited_coordinates:
401+
for coordinate in measured_coordinates:
419402
for i, v in enumerate(coordinate):
420403
covered_values_per_dimension[i][v] = True
421404

@@ -444,7 +427,7 @@ def _is_slow_mode(self):
444427
return False
445428

446429
passing_vectors, _ = self._get_measurements_passing_constraints()
447-
all_vectors, _ = self._get_all_visited_measurements()
430+
all_vectors, _ = self._get_all_measurements()
448431

449432
any_failing = len(all_vectors) != len(passing_vectors)
450433
any_passing = len(passing_vectors) != 0
@@ -453,9 +436,9 @@ def _is_slow_mode(self):
453436
return (home_passing and any_failing) or (not home_passing and
454437
any_passing)
455438

456-
def _are_all_adjacent_neighbors_visited(self):
439+
def _are_all_adjacent_neighbors_measured(self):
457440
for neighbor in self._get_all_adjacent_neighbors():
458-
if not self._is_coordinate_visited(neighbor):
441+
if not self._is_coordinate_measured(neighbor):
459442
return False
460443
return True
461444

model_analyzer/config/generate/quick_run_config_generator.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __init__(self, search_config: SearchConfig,
9090

9191
self._neighborhood = Neighborhood(
9292
self._search_config.get_neighborhood_config(),
93-
self._home_coordinate)
93+
self._home_coordinate, self._coordinate_data)
9494

9595
# Sticky bit. Once true, we should never stay at a home that is failing or None
9696
self._home_has_passed = False
@@ -142,11 +142,7 @@ def set_last_results(self, measurements: List[Union[RunConfigMeasurement,
142142
----------
143143
measurements: List of Measurements from the last run(s)
144144
"""
145-
self._coordinate_data.increment_visit_count(self._coordinate_to_measure)
146-
self._neighborhood.coordinate_data.increment_visit_count(
147-
coordinate=self._coordinate_to_measure)
148-
149-
self._neighborhood.coordinate_data.set_measurement(
145+
self._coordinate_data.set_measurement(
150146
coordinate=self._coordinate_to_measure, measurement=measurements[0])
151147

152148
if measurements[0] is not None:
@@ -187,8 +183,8 @@ def _update_best_measurement(self, measurement: RunConfigMeasurement):
187183
self._best_coordinate = self._coordinate_to_measure
188184
self._best_measurement = measurement
189185

190-
def _get_last_results(self) -> RunConfigMeasurement:
191-
return self._neighborhood.coordinate_data.get_measurement(
186+
def _get_last_results(self) -> Optional[RunConfigMeasurement]:
187+
return self._coordinate_data.get_measurement(
192188
coordinate=self._coordinate_to_measure)
193189

194190
def _take_step(self):
@@ -245,7 +241,11 @@ def _recreate_neighborhood(self, force_slow_mode: bool):
245241
neighborhood_config = self._search_config.get_neighborhood_config()
246242

247243
self._neighborhood = Neighborhood(neighborhood_config,
248-
self._home_coordinate)
244+
self._home_coordinate,
245+
self._coordinate_data)
246+
247+
self._coordinate_data.increment_visit_count(self._home_coordinate)
248+
249249
if force_slow_mode:
250250
self._neighborhood.force_slow_mode()
251251

tests/test_coordinate_data.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424

2525
class TestCoordinateData(trc.TestResultCollector):
2626

27-
def _construct_rcm(self,
28-
throughput: float,
29-
latency: float,
27+
def _construct_rcm(self, throughput: float, latency: float,
3028
config_name: str):
3129
model_config_name = [config_name]
3230

@@ -47,8 +45,7 @@ def _construct_rcm(self,
4745
gpu_metric_values={},
4846
non_gpu_metric_values=non_gpu_metric_values,
4947
metric_objectives=metric_objectives,
50-
model_config_weights=weights
51-
)
48+
model_config_weights=weights)
5249
return rcm
5350

5451
def test_basic(self):
@@ -57,6 +54,8 @@ def test_basic(self):
5754
coordinate = Coordinate([0, 0, 0])
5855
self.assertEqual(result_data.get_measurement(coordinate), None)
5956
self.assertEqual(result_data.get_visit_count(coordinate), 0)
57+
self.assertEqual(result_data.is_measured(coordinate), False)
58+
self.assertEqual(result_data.has_valid_measurement(coordinate), False)
6059

6160
def test_visit_count(self):
6261
result_data = CoordinateData()
@@ -77,24 +76,49 @@ def test_visit_count(self):
7776

7877
def test_measurement(self):
7978
"""
80-
Test if CoordinateData can properly set and get the measurements.
79+
Test if CoordinateData can properly set and get measurements
80+
81+
Also confirm that is_measured() and has_valid_measurement() work properly
8182
"""
8283
coordinate_data = CoordinateData()
8384

8485
coordinate0 = Coordinate([0, 0, 0])
8586
coordinate1 = Coordinate([0, 4, 1])
87+
coordinate2 = Coordinate([1, 2, 3])
8688

8789
rcm0 = self._construct_rcm(10, 5, config_name="modelA_config_0")
8890
rcm1 = self._construct_rcm(20, 8, config_name="modelB_config_0")
91+
rcm2 = None
8992

9093
coordinate_data.set_measurement(coordinate0, rcm0)
94+
coordinate_data.set_measurement(coordinate1, rcm1)
95+
coordinate_data.set_measurement(coordinate2, rcm2)
96+
97+
self.assertEqual(coordinate_data.is_measured(coordinate0), True)
98+
self.assertEqual(coordinate_data.is_measured(coordinate1), True)
99+
self.assertEqual(coordinate_data.is_measured(coordinate2), True)
100+
101+
self.assertEqual(coordinate_data.has_valid_measurement(coordinate0),
102+
True)
103+
self.assertEqual(coordinate_data.has_valid_measurement(coordinate1),
104+
True)
105+
self.assertEqual(coordinate_data.has_valid_measurement(coordinate2),
106+
False)
107+
91108
measurement0 = coordinate_data.get_measurement(coordinate0)
92109
self.assertEqual("modelA_config_0", measurement0.model_variants_name())
93-
self.assertEqual(10, measurement0.get_non_gpu_metric_value("perf_throughput"))
94-
self.assertEqual(5, measurement0.get_non_gpu_metric_value("perf_latency_avg"))
110+
self.assertEqual(
111+
10, measurement0.get_non_gpu_metric_value("perf_throughput"))
112+
self.assertEqual(
113+
5, measurement0.get_non_gpu_metric_value("perf_latency_avg"))
114+
self.assertTrue(coordinate_data.is_measured(coordinate0))
95115

96-
coordinate_data.set_measurement(coordinate1, rcm1)
97116
measurement1 = coordinate_data.get_measurement(coordinate1)
98117
self.assertEqual("modelB_config_0", measurement1.model_variants_name())
99-
self.assertEqual(20, measurement1.get_non_gpu_metric_value("perf_throughput"))
100-
self.assertEqual(8, measurement1.get_non_gpu_metric_value("perf_latency_avg"))
118+
self.assertEqual(
119+
20, measurement1.get_non_gpu_metric_value("perf_throughput"))
120+
self.assertEqual(
121+
8, measurement1.get_non_gpu_metric_value("perf_latency_avg"))
122+
123+
measurement2 = coordinate_data.get_measurement(coordinate2)
124+
self.assertEqual(measurement2, None)

0 commit comments

Comments
 (0)