@@ -49,8 +49,6 @@ def __init__(
49
49
self ._backend_context = backend_context
50
50
self ._experiment_id = experiment_id
51
51
self ._request = request
52
- self ._extractors = _create_extractors (request .col_params )
53
- self ._filters = _create_filters (request .col_params , self ._extractors )
54
52
55
53
def run (self ):
56
54
"""Handles the request specified on construction.
@@ -82,27 +80,29 @@ def run(self):
82
80
83
81
def _session_groups_from_tags (self ):
84
82
"""Constructs lists of SessionGroups based on hparam tag metadata."""
85
- # Query for all Hparams summary metadata up front to minimize calls to
83
+ # Query for all Hparams summary metadata one time to minimize calls to
86
84
# the underlying DataProvider.
87
- self ._hparams_run_to_tag_to_content = (
88
- self ._backend_context .hparams_metadata (
89
- self ._request_context , self ._experiment_id
90
- )
85
+ hparams_run_to_tag_to_content = self ._backend_context .hparams_metadata (
86
+ self ._request_context , self ._experiment_id
91
87
)
92
- # Since an context.experiment() call may search through all the runs, we
93
- # cache it here .
94
- self . _experiment = self ._backend_context .experiment_from_metadata (
88
+ # Construct the experiment one time since an context.experiment() call
89
+ # may search through all the runs .
90
+ experiment = self ._backend_context .experiment_from_metadata (
95
91
self ._request_context ,
96
92
self ._experiment_id ,
97
- self . _hparams_run_to_tag_to_content ,
93
+ hparams_run_to_tag_to_content ,
98
94
# Don't pass any information from the DataProvider since we are only
99
95
# examining session groups based on tag metadata
100
96
[],
101
97
)
98
+ extractors = _create_extractors (self ._request .col_params )
99
+ filters = _create_filters (self ._request .col_params , extractors )
102
100
103
- session_groups = self ._build_session_groups ()
104
- session_groups = self ._filter (session_groups )
105
- self ._sort (session_groups )
101
+ session_groups = self ._build_session_groups (
102
+ hparams_run_to_tag_to_content , experiment
103
+ )
104
+ session_groups = self ._filter (session_groups , filters )
105
+ self ._sort (session_groups , extractors )
106
106
return session_groups
107
107
108
108
def _session_groups_from_data_provider (self ):
@@ -151,7 +151,7 @@ def _session_groups_from_data_provider(self):
151
151
152
152
return session_groups
153
153
154
- def _build_session_groups (self ):
154
+ def _build_session_groups (self , hparams_run_to_tag_to_content , experiment ):
155
155
"""Returns a list of SessionGroups protobuffers from the summary
156
156
data."""
157
157
@@ -167,13 +167,13 @@ def _build_session_groups(self):
167
167
# contain metrics (may be in subdirectories).
168
168
session_names = [
169
169
run
170
- for (run , tags ) in self . _hparams_run_to_tag_to_content .items ()
170
+ for (run , tags ) in hparams_run_to_tag_to_content .items ()
171
171
if metadata .SESSION_START_INFO_TAG in tags
172
172
]
173
173
metric_runs = set ()
174
174
metric_tags = set ()
175
175
for session_name in session_names :
176
- for metric in self . _experiment .metric_infos :
176
+ for metric in experiment .metric_infos :
177
177
metric_name = metric .name
178
178
(run , tag ) = metrics .run_tag_from_session_and_metric (
179
179
session_name , metric_name
@@ -190,7 +190,7 @@ def _build_session_groups(self):
190
190
for (
191
191
session_name ,
192
192
tag_to_content ,
193
- ) in self . _hparams_run_to_tag_to_content .items ():
193
+ ) in hparams_run_to_tag_to_content .items ():
194
194
if metadata .SESSION_START_INFO_TAG not in tag_to_content :
195
195
continue
196
196
start_info = metadata .parse_session_start_info_plugin_data (
@@ -202,7 +202,7 @@ def _build_session_groups(self):
202
202
tag_to_content [metadata .SESSION_END_INFO_TAG ]
203
203
)
204
204
session = self ._build_session (
205
- session_name , start_info , end_info , all_metric_evals
205
+ experiment , session_name , start_info , end_info , all_metric_evals
206
206
)
207
207
if session .status in self ._request .allowed_statuses :
208
208
self ._add_session (session , start_info , groups_by_name )
@@ -257,7 +257,9 @@ def _add_session(self, session, start_info, groups_by_name):
257
257
group .hparams [key ].CopyFrom (value )
258
258
groups_by_name [group_name ] = group
259
259
260
- def _build_session (self , name , start_info , end_info , all_metric_evals ):
260
+ def _build_session (
261
+ self , experiment , name , start_info , end_info , all_metric_evals
262
+ ):
261
263
"""Builds a session object."""
262
264
263
265
assert start_info is not None
@@ -266,7 +268,7 @@ def _build_session(self, name, start_info, end_info, all_metric_evals):
266
268
start_time_secs = start_info .start_time_secs ,
267
269
model_uri = start_info .model_uri ,
268
270
metric_values = self ._build_session_metric_values (
269
- name , all_metric_evals
271
+ experiment , name , all_metric_evals
270
272
),
271
273
monitor_url = start_info .monitor_url ,
272
274
)
@@ -275,13 +277,14 @@ def _build_session(self, name, start_info, end_info, all_metric_evals):
275
277
result .end_time_secs = end_info .end_time_secs
276
278
return result
277
279
278
- def _build_session_metric_values (self , session_name , all_metric_evals ):
280
+ def _build_session_metric_values (
281
+ self , experiment , session_name , all_metric_evals
282
+ ):
279
283
"""Builds the session metric values."""
280
284
281
285
# result is a list of api_pb2.MetricValue instances.
282
286
result = []
283
- metric_infos = self ._experiment .metric_infos
284
- for metric_info in metric_infos :
287
+ for metric_info in experiment .metric_infos :
285
288
metric_name = metric_info .name
286
289
(run , tag ) = metrics .run_tag_from_session_and_metric (
287
290
session_name , metric_name
@@ -327,13 +330,15 @@ def _aggregate_metrics(self, session_group):
327
330
% self ._request .aggregation_type
328
331
)
329
332
330
- def _filter (self , session_groups ):
331
- return [sg for sg in session_groups if self ._passes_all_filters (sg )]
333
+ def _filter (self , session_groups , filters ):
334
+ return [
335
+ sg for sg in session_groups if self ._passes_all_filters (sg , filters )
336
+ ]
332
337
333
- def _passes_all_filters (self , session_group ):
334
- return all (filter_fn (session_group ) for filter_fn in self . _filters )
338
+ def _passes_all_filters (self , session_group , filters ):
339
+ return all (filter_fn (session_group ) for filter_fn in filters )
335
340
336
- def _sort (self , session_groups ):
341
+ def _sort (self , session_groups , extractors ):
337
342
"""Sorts 'session_groups' in place according to _request.col_params."""
338
343
339
344
# Sort by session_group name so we have a deterministic order.
@@ -344,7 +349,7 @@ def _sort(self, session_groups):
344
349
# need to iterate on these columns in reverse order (thus the primary key
345
350
# is the key used in the last sort).
346
351
for col_param , extractor in reversed (
347
- list (zip (self ._request .col_params , self . _extractors ))
352
+ list (zip (self ._request .col_params , extractors ))
348
353
):
349
354
if col_param .order == api_pb2 .ORDER_UNSPECIFIED :
350
355
continue
0 commit comments