Skip to content

Commit fd3b82e

Browse files
authored
Filter metric values in /list_session_groups. (#6550)
Handle metric-based filters for DataProvider-based hparam requests. There are three parts:
* Stop sending the metric-based filters to the DataProvider. The DataProvider does not have the metric data to apply the filtering.
* Generate local filters. The generated local filters are based on metrics only. Local filters are not generated for hparams so as not to repeat the work of the DataProvider.
* Apply the local filters after session group metrics have been retrieved and aggregated.
1 parent 7e30969 commit fd3b82e

File tree

2 files changed

+155
-1
lines changed

2 files changed

+155
-1
lines changed

tensorboard/plugins/hparams/list_session_groups.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,15 @@ def _session_groups_from_data_provider(self):
186186
if group.sessions:
187187
self._aggregate_metrics(group)
188188

189+
extractors = _create_extractors(self._request.col_params)
190+
filters = _create_filters(
191+
self._request.col_params,
192+
extractors,
193+
# We assume the DataProvider will apply hparam filters and we do not
194+
# attempt to reapply them.
195+
include_hparam_filters=False,
196+
)
197+
session_groups = self._filter(session_groups, filters)
189198
return session_groups
190199

191200
def _build_session_groups(
@@ -552,20 +561,25 @@ def extractor_fn(session_group):
552561
# True if it should be included in the result. Currently, Filters are functions
553562
# of a single column value extracted from the session group with a given
554563
# extractor specified in the construction of the filter.
555-
def _create_filters(col_params, extractors):
564+
def _create_filters(col_params, extractors, *, include_hparam_filters=True):
556565
"""Creates filters for the given col_params.
557566
558567
Args:
559568
col_params: List of ListSessionGroupsRequest.ColParam protobufs.
560569
extractors: list of extractor functions of the same length as col_params.
561570
Each element should extract the column described by the corresponding
562571
element of col_params.
572+
include_hparam_filters: bool that indicates whether hparam filters should
573+
be generated. Defaults to True.
563574
Returns:
564575
A list of filter functions, each corresponding to a single
565576
col_params.filter oneof field of _request
566577
"""
567578
result = []
568579
for col_param, extractor in zip(col_params, extractors):
580+
if not include_hparam_filters and col_param.hparam:
581+
continue
582+
569583
a_filter = _create_filter(col_param, extractor)
570584
if a_filter:
571585
result.append(a_filter)
@@ -860,6 +874,11 @@ def _build_data_provider_filters(col_params):
860874
"""Builds HyperparameterFilters from ColParams."""
861875
filters = []
862876
for col_param in col_params:
877+
if not col_param.hparam:
878+
# We do not pass metric filters to the DataProvider as it does not
879+
# have the metric data for filtering.
880+
continue
881+
863882
fltr = _build_data_provider_filter(col_param)
864883
if fltr is None:
865884
continue

tensorboard/plugins/hparams/list_session_groups_test.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,6 +1536,21 @@ def test_experiment_from_data_provider_sends_discrete_filter(self):
15361536
],
15371537
)
15381538

1539+
def test_experiment_from_data_provider_does_not_send_metric_filters(self):
1540+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
1541+
request = """
1542+
col_params: {
1543+
metric: { tag: 'delta_temp' }
1544+
filter_interval: {
1545+
min_value: 0
1546+
max_value: 100
1547+
}
1548+
}
1549+
"""
1550+
self._run_handler(request)
1551+
1552+
self.assertEmpty(self._get_read_hyperparameters_call_filters())
1553+
15391554
def test_experiment_from_data_provider_sends_sort(self):
15401555
self._mock_tb_context.data_provider.list_tensors.side_effect = None
15411556
request = """
@@ -2169,6 +2184,126 @@ def test_experiment_from_data_provider_with_metric_values_aggregates(
21692184
response.session_groups[0].metric_values[2],
21702185
)
21712186

2187+
def test_experiment_from_data_provider_filters_by_metric_values(
2188+
self,
2189+
):
2190+
# Filters are tested in-depth elsewhere using the Tensor-based hparams.
2191+
# For DataProvider-based hparam tests we just test one filter to verify
2192+
# the filter logic is being applied.
2193+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
2194+
self._hyperparameters = [
2195+
# The sessions names correspond to return values from
2196+
# _mock_list_scalars() and _mock_read_scalars() in order to
2197+
# generate metric infos and values.
2198+
provider.HyperparameterSessionGroup(
2199+
root=provider.HyperparameterSessionRun(
2200+
experiment_id="session_1", run=""
2201+
),
2202+
sessions=[
2203+
provider.HyperparameterSessionRun(
2204+
experiment_id="session_1", run=""
2205+
)
2206+
],
2207+
hyperparameter_values=[],
2208+
),
2209+
provider.HyperparameterSessionGroup(
2210+
root=provider.HyperparameterSessionRun(
2211+
experiment_id="session_2", run=""
2212+
),
2213+
sessions=[
2214+
provider.HyperparameterSessionRun(
2215+
experiment_id="session_2", run=""
2216+
)
2217+
],
2218+
hyperparameter_values=[],
2219+
),
2220+
provider.HyperparameterSessionGroup(
2221+
root=provider.HyperparameterSessionRun(
2222+
experiment_id="session_3", run=""
2223+
),
2224+
sessions=[
2225+
provider.HyperparameterSessionRun(
2226+
experiment_id="session_3", run=""
2227+
)
2228+
],
2229+
hyperparameter_values=[],
2230+
),
2231+
]
2232+
request = """
2233+
start_index: 0
2234+
slice_size: 10
2235+
"""
2236+
response = self._run_handler(request)
2237+
self.assertLen(response.session_groups, 3)
2238+
self.assertEqual("session_1", response.session_groups[0].name)
2239+
self.assertEqual("session_2", response.session_groups[1].name)
2240+
self.assertEqual("session_3", response.session_groups[2].name)
2241+
2242+
filtered_request = """
2243+
start_index: 0
2244+
slice_size: 10
2245+
col_params: {
2246+
metric: { tag: 'delta_temp' }
2247+
filter_interval: {
2248+
min_value: 0
2249+
max_value: 100
2250+
}
2251+
}
2252+
"""
2253+
filtered_response = self._run_handler(filtered_request)
2254+
# The delta_temp values for session_1, session_2, and session_3 are
2255+
# 10, 150, and 1.5, respectively. We expect session_2 to have been
2256+
# filtered out.
2257+
self.assertLen(filtered_response.session_groups, 2)
2258+
self.assertEqual("session_1", filtered_response.session_groups[0].name)
2259+
self.assertEqual("session_3", filtered_response.session_groups[1].name)
2260+
2261+
def test_experiment_from_data_provider_does_not_filter_by_hparam_values(
2262+
self,
2263+
):
2264+
# We assume the DataProvider will apply hparam filters and we do not
2265+
# attempt to reapply them.
2266+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
2267+
self._hyperparameters = [
2268+
provider.HyperparameterSessionGroup(
2269+
root=provider.HyperparameterSessionRun(
2270+
experiment_id="session_1", run=""
2271+
),
2272+
sessions=[
2273+
provider.HyperparameterSessionRun(
2274+
experiment_id="session_1", run=""
2275+
)
2276+
],
2277+
hyperparameter_values=[
2278+
provider.HyperparameterValue(
2279+
hyperparameter_name="hparam1",
2280+
domain_type=provider.HyperparameterDomainType.INTERVAL,
2281+
value=-1.0,
2282+
),
2283+
],
2284+
),
2285+
]
2286+
request = """
2287+
start_index: 0
2288+
slice_size: 10
2289+
col_params: {
2290+
hparam: 'hparam1'
2291+
filter_interval: {
2292+
min_value: 0
2293+
max_value: 100
2294+
}
2295+
}
2296+
"""
2297+
response = self._run_handler(request)
2298+
# The one result from the DataProvider call is returned even though
2299+
# there is an hparam filter that it should not pass. This indicates we
2300+
# are purposefully not applying the hparam filters.
2301+
#
2302+
# Note: This scenario should not happen in practice as we'd expect
2303+
# the DataProvider to have successfully applied the filter.
2304+
self.assertLen(response.session_groups, 1)
2305+
self.assertEqual("session_1", response.session_groups[0].name)
2306+
21722307
def _run_handler(self, request):
21732308
request_proto = api_pb2.ListSessionGroupsRequest()
21742309
text_format.Merge(request, request_proto)

0 commit comments

Comments
 (0)