Skip to content

Commit 710f681

Browse files
authored
Refactor list_session_groups_test to use DataProvider mock. (#6462)
Note that this is similar to #6386. The list_session_groups_test.py has, for historical reasons, mocked its data dependencies at the "multiplexer" level. When list_sessions_groups.py was migrated to use DataProvider interface in 2020 (See #3425), its tests weren't fully converted to mock data dependencies at the "DataProvider" level. We will soon be adding more logic to hparams/list_sessions_groups.py and it would be convenient to mock data dependencies at the "DataProvider" level instead of the "multiplexer" level. This PR updates existing list_session_groups_test tests so that they mock at the "DataProvider" level. The refactor can be done fairly cleanly - just the helper functions need to change.
1 parent 44f47ef commit 710f681

File tree

1 file changed

+64
-106
lines changed

1 file changed

+64
-106
lines changed

tensorboard/plugins/hparams/list_session_groups_test.py

Lines changed: 64 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -22,58 +22,39 @@
2222

2323
from google.protobuf import text_format
2424
from tensorboard import context
25-
from tensorboard.backend.event_processing import data_provider
26-
from tensorboard.backend.event_processing import event_accumulator
27-
from tensorboard.backend.event_processing import plugin_event_multiplexer
28-
from tensorboard.compat.proto import summary_pb2
25+
from tensorboard.data import provider
2926
from tensorboard.plugins import base_plugin
3027
from tensorboard.plugins.hparams import api_pb2
3128
from tensorboard.plugins.hparams import backend_context
3229
from tensorboard.plugins.hparams import list_session_groups
3330
from tensorboard.plugins.hparams import metadata
3431
from tensorboard.plugins.hparams import plugin_data_pb2
35-
from tensorboard.plugins.scalar import metadata as scalars_metadata
3632

3733

3834
DATA_TYPE_EXPERIMENT = "experiment"
3935
DATA_TYPE_SESSION_START_INFO = "session_start_info"
4036
DATA_TYPE_SESSION_END_INFO = "session_end_info"
4137

4238

43-
# Allow us to abbreviate event_accumulator.TensorEvent
44-
TensorEvent = event_accumulator.TensorEvent # pylint: disable=invalid-name
45-
46-
4739
class ListSessionGroupsTest(tf.test.TestCase):
4840
# Make assertProtoEquals print all the diff.
4941
maxDiff = None # pylint: disable=invalid-name
5042

5143
def setUp(self):
5244
self._mock_tb_context = base_plugin.TBContext()
53-
# TODO(#3425): Remove mocking or switch to mocking data provider
54-
# APIs directly.
55-
self._mock_multiplexer = mock.create_autospec(
56-
plugin_event_multiplexer.EventMultiplexer
57-
)
58-
self._mock_tb_context.multiplexer = self._mock_multiplexer
59-
self._mock_multiplexer.PluginRunToTagToContent.side_effect = (
60-
self._mock_plugin_run_to_tag_to_content
45+
self._mock_tb_context.data_provider = mock.create_autospec(
46+
provider.DataProvider
6147
)
62-
self._mock_multiplexer.AllSummaryMetadata.side_effect = (
63-
self._mock_all_summary_metadata
48+
self._mock_tb_context.data_provider.list_tensors.side_effect = (
49+
self._mock_list_tensors
6450
)
65-
self._mock_multiplexer.SummaryMetadata.side_effect = (
66-
self._mock_summary_metadata
67-
)
68-
self._mock_multiplexer.Tensors.side_effect = self._mock_tensors
69-
self._mock_tb_context.data_provider = (
70-
data_provider.MultiplexerDataProvider(
71-
self._mock_multiplexer, "/path/to/logs"
72-
)
51+
self._mock_tb_context.data_provider.read_scalars.side_effect = (
52+
self._mock_read_scalars
7353
)
7454

75-
def _mock_all_summary_metadata(self):
76-
result = {}
55+
def _mock_list_tensors(
56+
self, ctx, *, experiment_id, plugin_name, run_tag_filter
57+
):
7758
hparams_content = {
7859
"": {
7960
metadata.EXPERIMENT_TAG: self._serialized_plugin_data(
@@ -216,54 +197,31 @@ def _mock_all_summary_metadata(self):
216197
),
217198
},
218199
}
219-
scalars_content = {
220-
"session_1": {
221-
"current_temp": b"",
222-
"delta_temp": b"",
223-
"optional_metric": b"",
224-
},
225-
"session_2": {"current_temp": b"", "delta_temp": b""},
226-
"session_3": {"current_temp": b"", "delta_temp": b""},
227-
"session_4": {"current_temp": b"", "delta_temp": b""},
228-
"session_5": {"current_temp": b"", "delta_temp": b""},
229-
}
200+
result = {}
230201
for (run, tag_to_content) in hparams_content.items():
231202
result.setdefault(run, {})
232203
for (tag, content) in tag_to_content.items():
233-
m = summary_pb2.SummaryMetadata()
234-
m.data_class = summary_pb2.DATA_CLASS_TENSOR
235-
m.plugin_data.plugin_name = metadata.PLUGIN_NAME
236-
m.plugin_data.content = content
237-
result[run][tag] = m
238-
for (run, tag_to_content) in scalars_content.items():
239-
result.setdefault(run, {})
240-
for (tag, content) in tag_to_content.items():
241-
m = summary_pb2.SummaryMetadata()
242-
m.data_class = summary_pb2.DATA_CLASS_SCALAR
243-
m.plugin_data.plugin_name = scalars_metadata.PLUGIN_NAME
244-
m.plugin_data.content = content
245-
result[run][tag] = m
246-
return result
247-
248-
def _mock_plugin_run_to_tag_to_content(self, plugin_name):
249-
result = {}
250-
for (run, tag_to_metadata) in self._mock_all_summary_metadata().items():
251-
for (tag, metadata) in tag_to_metadata.items():
252-
if metadata.plugin_data.plugin_name != plugin_name:
253-
continue
254-
result.setdefault(run, {})
255-
result[run][tag] = metadata.plugin_data.content
204+
t = provider.TensorTimeSeries(
205+
max_step=0,
206+
max_wall_time=0,
207+
plugin_content=content,
208+
description="",
209+
display_name="",
210+
)
211+
result[run][tag] = t
256212
return result
257213

258-
def _mock_summary_metadata(self, run, tag):
259-
return self._mock_all_summary_metadata()[run][tag]
260-
261-
# A mock version of EventMultiplexer.Tensors
262-
def _mock_tensors(self, run, tag):
214+
def _mock_read_scalars(
215+
self,
216+
ctx=None,
217+
*,
218+
experiment_id,
219+
plugin_name,
220+
downsample=None,
221+
run_tag_filter=None,
222+
):
263223
hparams_time_series = [
264-
TensorEvent(
265-
wall_time=123.75, step=0, tensor_proto=metadata.NULL_TENSOR
266-
)
224+
provider.ScalarDatum(wall_time=123.75, step=0, value=0.0)
267225
]
268226
result_dict = {
269227
"": {
@@ -273,131 +231,131 @@ def _mock_tensors(self, run, tag):
273231
metadata.SESSION_START_INFO_TAG: hparams_time_series[:],
274232
metadata.SESSION_END_INFO_TAG: hparams_time_series[:],
275233
"current_temp": [
276-
TensorEvent(
234+
provider.ScalarDatum(
277235
wall_time=1,
278236
step=1,
279-
tensor_proto=tf.compat.v1.make_tensor_proto(10.0),
237+
value=10.0,
280238
)
281239
],
282240
"delta_temp": [
283-
TensorEvent(
241+
provider.ScalarDatum(
284242
wall_time=1,
285243
step=1,
286-
tensor_proto=tf.compat.v1.make_tensor_proto(20.0),
244+
value=20.0,
287245
),
288-
TensorEvent(
246+
provider.ScalarDatum(
289247
wall_time=10,
290248
step=2,
291-
tensor_proto=tf.compat.v1.make_tensor_proto(15.0),
249+
value=15.0,
292250
),
293251
],
294252
"optional_metric": [
295-
TensorEvent(
253+
provider.ScalarDatum(
296254
wall_time=1,
297255
step=1,
298-
tensor_proto=tf.compat.v1.make_tensor_proto(20.0),
256+
value=20.0,
299257
),
300-
TensorEvent(
258+
provider.ScalarDatum(
301259
wall_time=2,
302260
step=20,
303-
tensor_proto=tf.compat.v1.make_tensor_proto(33.0),
261+
value=33.0,
304262
),
305263
],
306264
},
307265
"session_2": {
308266
metadata.SESSION_START_INFO_TAG: hparams_time_series[:],
309267
metadata.SESSION_END_INFO_TAG: hparams_time_series[:],
310268
"current_temp": [
311-
TensorEvent(
269+
provider.ScalarDatum(
312270
wall_time=1,
313271
step=1,
314-
tensor_proto=tf.compat.v1.make_tensor_proto(100.0),
272+
value=100.0,
315273
),
316274
],
317275
"delta_temp": [
318-
TensorEvent(
276+
provider.ScalarDatum(
319277
wall_time=1,
320278
step=1,
321-
tensor_proto=tf.compat.v1.make_tensor_proto(200.0),
279+
value=200.0,
322280
),
323-
TensorEvent(
281+
provider.ScalarDatum(
324282
wall_time=11,
325283
step=3,
326-
tensor_proto=tf.compat.v1.make_tensor_proto(150.0),
284+
value=150.0,
327285
),
328286
],
329287
},
330288
"session_3": {
331289
metadata.SESSION_START_INFO_TAG: hparams_time_series[:],
332290
metadata.SESSION_END_INFO_TAG: hparams_time_series[:],
333291
"current_temp": [
334-
TensorEvent(
292+
provider.ScalarDatum(
335293
wall_time=1,
336294
step=1,
337-
tensor_proto=tf.compat.v1.make_tensor_proto(1.0),
295+
value=1.0,
338296
),
339297
],
340298
"delta_temp": [
341-
TensorEvent(
299+
provider.ScalarDatum(
342300
wall_time=1,
343301
step=1,
344-
tensor_proto=tf.compat.v1.make_tensor_proto(2.0),
302+
value=2.0,
345303
),
346-
TensorEvent(
304+
provider.ScalarDatum(
347305
wall_time=10,
348306
step=2,
349-
tensor_proto=tf.compat.v1.make_tensor_proto(1.5),
307+
value=1.5,
350308
),
351309
],
352310
},
353311
"session_4": {
354312
metadata.SESSION_START_INFO_TAG: hparams_time_series[:],
355313
metadata.SESSION_END_INFO_TAG: hparams_time_series[:],
356314
"current_temp": [
357-
TensorEvent(
315+
provider.ScalarDatum(
358316
wall_time=1,
359317
step=1,
360-
tensor_proto=tf.compat.v1.make_tensor_proto(101.0),
318+
value=101.0,
361319
),
362320
],
363321
"delta_temp": [
364-
TensorEvent(
322+
provider.ScalarDatum(
365323
wall_time=1,
366324
step=1,
367-
tensor_proto=tf.compat.v1.make_tensor_proto(201.0),
325+
value=201.0,
368326
),
369-
TensorEvent(
327+
provider.ScalarDatum(
370328
wall_time=10,
371329
step=2,
372-
tensor_proto=tf.compat.v1.make_tensor_proto(-151.0),
330+
value=-151.0,
373331
),
374332
],
375333
},
376334
"session_5": {
377335
metadata.SESSION_START_INFO_TAG: hparams_time_series[:],
378336
metadata.SESSION_END_INFO_TAG: hparams_time_series[:],
379337
"current_temp": [
380-
TensorEvent(
338+
provider.ScalarDatum(
381339
wall_time=1,
382340
step=1,
383-
tensor_proto=tf.compat.v1.make_tensor_proto(52.0),
341+
value=52.0,
384342
),
385343
],
386344
"delta_temp": [
387-
TensorEvent(
345+
provider.ScalarDatum(
388346
wall_time=1,
389347
step=1,
390-
tensor_proto=tf.compat.v1.make_tensor_proto(2.0),
348+
value=2.0,
391349
),
392-
TensorEvent(
350+
provider.ScalarDatum(
393351
wall_time=10,
394352
step=2,
395-
tensor_proto=tf.compat.v1.make_tensor_proto(-18),
353+
value=-18,
396354
),
397355
],
398356
},
399357
}
400-
return result_dict[run][tag]
358+
return result_dict
401359

402360
def test_empty_request(self):
403361
# Since we don't allow any statuses, result should be empty.

0 commit comments

Comments
 (0)