@@ -1536,6 +1536,21 @@ def test_experiment_from_data_provider_sends_discrete_filter(self):
1536
1536
],
1537
1537
)
1538
1538
1539
+ def test_experiment_from_data_provider_does_not_send_metric_filters (self ):
1540
+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
1541
+ request = """
1542
+ col_params: {
1543
+ metric: { tag: 'delta_temp' }
1544
+ filter_interval: {
1545
+ min_value: 0
1546
+ max_value: 100
1547
+ }
1548
+ }
1549
+ """
1550
+ self ._run_handler (request )
1551
+
1552
+ self .assertEmpty (self ._get_read_hyperparameters_call_filters ())
1553
+
1539
1554
def test_experiment_from_data_provider_sends_sort (self ):
1540
1555
self ._mock_tb_context .data_provider .list_tensors .side_effect = None
1541
1556
request = """
@@ -2169,6 +2184,126 @@ def test_experiment_from_data_provider_with_metric_values_aggregates(
2169
2184
response .session_groups [0 ].metric_values [2 ],
2170
2185
)
2171
2186
2187
+ def test_experiment_from_data_provider_filters_by_metric_values (
2188
+ self ,
2189
+ ):
2190
+ # Filters are tested in-depth elsewhere using the Tensor-based hparams.
2191
+ # For DataProvider-based hparam tests we just test one filter to verify
2192
+ # the filter logic is being applied.
2193
+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
2194
+ self ._hyperparameters = [
2195
+ # The sessions names correspond to return values from
2196
+ # _mock_list_scalars() and _mock_read_scalars() in order to
2197
+ # generate metric infos and values.
2198
+ provider .HyperparameterSessionGroup (
2199
+ root = provider .HyperparameterSessionRun (
2200
+ experiment_id = "session_1" , run = ""
2201
+ ),
2202
+ sessions = [
2203
+ provider .HyperparameterSessionRun (
2204
+ experiment_id = "session_1" , run = ""
2205
+ )
2206
+ ],
2207
+ hyperparameter_values = [],
2208
+ ),
2209
+ provider .HyperparameterSessionGroup (
2210
+ root = provider .HyperparameterSessionRun (
2211
+ experiment_id = "session_2" , run = ""
2212
+ ),
2213
+ sessions = [
2214
+ provider .HyperparameterSessionRun (
2215
+ experiment_id = "session_2" , run = ""
2216
+ )
2217
+ ],
2218
+ hyperparameter_values = [],
2219
+ ),
2220
+ provider .HyperparameterSessionGroup (
2221
+ root = provider .HyperparameterSessionRun (
2222
+ experiment_id = "session_3" , run = ""
2223
+ ),
2224
+ sessions = [
2225
+ provider .HyperparameterSessionRun (
2226
+ experiment_id = "session_3" , run = ""
2227
+ )
2228
+ ],
2229
+ hyperparameter_values = [],
2230
+ ),
2231
+ ]
2232
+ request = """
2233
+ start_index: 0
2234
+ slice_size: 10
2235
+ """
2236
+ response = self ._run_handler (request )
2237
+ self .assertLen (response .session_groups , 3 )
2238
+ self .assertEqual ("session_1" , response .session_groups [0 ].name )
2239
+ self .assertEqual ("session_2" , response .session_groups [1 ].name )
2240
+ self .assertEqual ("session_3" , response .session_groups [2 ].name )
2241
+
2242
+ filtered_request = """
2243
+ start_index: 0
2244
+ slice_size: 10
2245
+ col_params: {
2246
+ metric: { tag: 'delta_temp' }
2247
+ filter_interval: {
2248
+ min_value: 0
2249
+ max_value: 100
2250
+ }
2251
+ }
2252
+ """
2253
+ filtered_response = self ._run_handler (filtered_request )
2254
+ # The delta_temp values for session_1, session_2, and session_3 are
2255
+ # 10, 150, and 1.5, respectively. We expect session_2 to have been
2256
+ # filtered out.
2257
+ self .assertLen (filtered_response .session_groups , 2 )
2258
+ self .assertEqual ("session_1" , filtered_response .session_groups [0 ].name )
2259
+ self .assertEqual ("session_3" , filtered_response .session_groups [1 ].name )
2260
+
2261
+ def test_experiment_from_data_provider_does_not_filter_by_hparam_values (
2262
+ self ,
2263
+ ):
2264
+ # We assume the DataProvider will apply hparam filters and we do not
2265
+ # attempt to reapply them.
2266
+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
2267
+ self ._hyperparameters = [
2268
+ provider .HyperparameterSessionGroup (
2269
+ root = provider .HyperparameterSessionRun (
2270
+ experiment_id = "session_1" , run = ""
2271
+ ),
2272
+ sessions = [
2273
+ provider .HyperparameterSessionRun (
2274
+ experiment_id = "session_1" , run = ""
2275
+ )
2276
+ ],
2277
+ hyperparameter_values = [
2278
+ provider .HyperparameterValue (
2279
+ hyperparameter_name = "hparam1" ,
2280
+ domain_type = provider .HyperparameterDomainType .INTERVAL ,
2281
+ value = - 1.0 ,
2282
+ ),
2283
+ ],
2284
+ ),
2285
+ ]
2286
+ request = """
2287
+ start_index: 0
2288
+ slice_size: 10
2289
+ col_params: {
2290
+ hparam: 'hparam1'
2291
+ filter_interval: {
2292
+ min_value: 0
2293
+ max_value: 100
2294
+ }
2295
+ }
2296
+ """
2297
+ response = self ._run_handler (request )
2298
+ # The one result from the DataProvider call is returned even though
2299
+ # there is an hparam filter that it should not pass. This indicates we
2300
+ # are purposefully not applying the hparam filters.
2301
+ #
2302
+ # Note: The scenario should not happen in practice as we'd expect
2303
+ # the DataProvider to have successfully applied the filter.
2304
+ self .assertLen (response .session_groups , 1 )
2305
+ self .assertEqual ("session_1" , response .session_groups [0 ].name )
2306
+
2172
2307
def _run_handler (self , request ):
2173
2308
request_proto = api_pb2 .ListSessionGroupsRequest ()
2174
2309
text_format .Merge (request , request_proto )
0 commit comments