@@ -50,32 +50,105 @@ def __init__(
50
50
self ._request = request
51
51
self ._extractors = _create_extractors (request .col_params )
52
52
self ._filters = _create_filters (request .col_params , self ._extractors )
53
+
54
+ def run (self ):
55
+ """Handles the request specified on construction.
56
+
57
+ This operation first attempts to construct SessionGroup information
58
+ from hparam tags metadata.EXPERIMENT_TAG and
59
+ metadata.SESSION_START_INFO.
60
+
61
+ If no such tags are found, then will build SessionGroup information
62
+ using the results from DataProvider.read_hyperparameters().
63
+
64
+ Returns:
65
+ A ListSessionGroupsResponse object.
66
+ """
67
+
68
+ session_groups_from_tags = self ._session_groups_from_tags ()
69
+ if session_groups_from_tags :
70
+ return self ._create_response (session_groups_from_tags )
71
+
72
+ session_groups_from_data_provider = (
73
+ self ._session_groups_from_data_provider ()
74
+ )
75
+ if session_groups_from_data_provider :
76
+ return self ._create_response (session_groups_from_data_provider )
77
+
78
+ return api_pb2 .ListSessionGroupsResponse (
79
+ session_groups = [], total_size = 0
80
+ )
81
+
82
+ def _session_groups_from_tags (self ):
83
+ """Constructs lists of SessionGroups based on hparam tag metadata."""
53
84
# Query for all Hparams summary metadata up front to minimize calls to
54
85
# the underlying DataProvider.
55
- self ._hparams_run_to_tag_to_content = backend_context .hparams_metadata (
56
- request_context , experiment_id
86
+ self ._hparams_run_to_tag_to_content = (
87
+ self ._backend_context .hparams_metadata (
88
+ self ._request_context , self ._experiment_id
89
+ )
57
90
)
58
91
# Since an context.experiment() call may search through all the runs, we
59
92
# cache it here.
60
- self ._experiment = backend_context .experiment_from_metadata (
61
- request_context ,
62
- experiment_id ,
93
+ self ._experiment = self . _backend_context .experiment_from_metadata (
94
+ self . _request_context ,
95
+ self . _experiment_id ,
63
96
self ._hparams_run_to_tag_to_content ,
64
- self . _backend_context . hparams_from_data_provider (
65
- request_context , experiment_id
66
- ) ,
97
+ # Don't pass any information from the DataProvider since we are only
98
+ # examining session groups based on tag metadata
99
+ [] ,
67
100
)
68
101
69
- def run (self ):
70
- """Handles the request specified on construction.
71
-
72
- Returns:
73
- A ListSessionGroupsResponse object.
74
- """
75
102
session_groups = self ._build_session_groups ()
76
103
session_groups = self ._filter (session_groups )
77
104
self ._sort (session_groups )
78
- return self ._create_response (session_groups )
105
+ return session_groups
106
+
107
+ def _session_groups_from_data_provider (self ):
108
+ """Constructs lists of SessionGroups based on DataProvider results."""
109
+ response = self ._backend_context .session_groups_from_data_provider (
110
+ self ._request_context , self ._experiment_id
111
+ )
112
+
113
+ session_groups = []
114
+ for provider_group in response :
115
+ sessions = [
116
+ api_pb2 .Session (name = f"{ s .experiment_id } /{ s .run } " )
117
+ for s in provider_group .sessions
118
+ ]
119
+ name = (
120
+ f"{ provider_group .root .experiment_id } /{ provider_group .root .run } "
121
+ if provider_group .root .run
122
+ else provider_group .root .experiment_id
123
+ )
124
+ session_group = api_pb2 .SessionGroup (
125
+ name = name ,
126
+ sessions = sessions ,
127
+ )
128
+
129
+ for provider_hparam in provider_group .hyperparameter_values :
130
+ hparam = session_group .hparams [
131
+ provider_hparam .hyperparameter_name
132
+ ]
133
+ if (
134
+ provider_hparam .domain_type
135
+ == provider .HyperparameterDomainType .DISCRETE_STRING
136
+ ):
137
+ hparam .string_value = provider_hparam .value
138
+ elif provider_hparam .domain_type in [
139
+ provider .HyperparameterDomainType .DISCRETE_FLOAT ,
140
+ provider .HyperparameterDomainType .INTERVAL ,
141
+ ]:
142
+ hparam .number_value = provider_hparam .value
143
+ elif (
144
+ provider_hparam .domain_type
145
+ == provider .HyperparameterDomainType .DISCRETE_BOOL
146
+ ):
147
+ hparam .bool_value = provider_hparam .value
148
+
149
+ session_groups .append (session_group )
150
+
151
+ return session_groups
79
152
80
153
def _build_session_groups (self ):
81
154
"""Returns a list of SessionGroups protobuffers from the summary
0 commit comments