Skip to content

Commit 3bae61c

Browse files
authored
HParam: Add a new selector to create columns for hparams (#6406)
## Motivation for features / changes As part of the upcoming hparams in timeseries feature we want to be able to display hparams as columns in both the scalar card data table and the runs table. Because we do not want all the hparams displayed all the time (there may be too many) we are building an interface for selecting columns. Towards that goal I have created this selector which finds all the hparams and generates a column header for them. ## Screenshots of UI changes (or N/A) N/A
1 parent 917cc7e commit 3bae61c

File tree

6 files changed

+113
-2
lines changed

6 files changed

+113
-2
lines changed

tensorboard/webapp/metrics/views/main_view/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ tf_ts_library(
8989
"//tensorboard/webapp/metrics/store",
9090
"//tensorboard/webapp/metrics/views:types",
9191
"//tensorboard/webapp/metrics/views:utils",
92+
"//tensorboard/webapp/metrics/views/card_renderer:scalar_card_types",
9293
"//tensorboard/webapp/runs:types",
9394
"//tensorboard/webapp/runs/views/runs_table:types",
9495
"//tensorboard/webapp/util:matcher",
@@ -194,13 +195,15 @@ tf_ts_library(
194195
"//tensorboard/webapp/customization",
195196
"//tensorboard/webapp/experiments/store:testing",
196197
"//tensorboard/webapp/hparams:types",
198+
"//tensorboard/webapp/hparams/_redux:testing",
197199
"//tensorboard/webapp/metrics:test_lib",
198200
"//tensorboard/webapp/metrics:types",
199201
"//tensorboard/webapp/metrics/actions",
200202
"//tensorboard/webapp/metrics/data_source",
201203
"//tensorboard/webapp/metrics/store",
202204
"//tensorboard/webapp/metrics/views:types",
203205
"//tensorboard/webapp/metrics/views/card_renderer",
206+
"//tensorboard/webapp/metrics/views/card_renderer:scalar_card_types",
204207
"//tensorboard/webapp/runs/store:selectors",
205208
"//tensorboard/webapp/runs/store:testing",
206209
"//tensorboard/webapp/runs/store:types",

tensorboard/webapp/metrics/views/main_view/common_selectors.ts

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ import {
2525
getRunSelectorRegexFilter,
2626
getRouteKind,
2727
getRunsFromExperimentIds,
28+
getColumnHeadersForCard,
29+
getCardMetadata,
2830
} from '../../../selectors';
2931
import {DeepReadonly} from '../../../util/types';
3032
import {
3133
getHparamFilterMapFromExperimentIds,
3234
getMetricFilterMapFromExperimentIds,
35+
getExperimentsHparamsAndMetricsSpecs,
3336
} from '../../../hparams/_redux/hparams_selectors';
3437
import {
3538
DiscreteFilter,
@@ -48,6 +51,10 @@ import {compareTagNames} from '../../utils';
4851
import {CardIdWithMetadata} from '../metrics_view_types';
4952
import {RouteKind} from '../../../app_routing/types';
5053
import {memoize} from '../../../util/memoize';
54+
import {
55+
ColumnHeader,
56+
ColumnHeaderType,
57+
} from '../card_renderer/scalar_card_types';
5158

5259
export const getScalarTagsForRunSelection = createSelector(
5360
getMetricsTagMetadata,
@@ -254,6 +261,39 @@ export const getFilteredRenderableRunsIdsFromRoute = createSelector(
254261
}
255262
);
256263

264+
export const getPotentialHparamColumns = createSelector(
265+
(state: State) => state,
266+
getExperimentIdsFromRoute,
267+
(state, experimentIds): ColumnHeader[] => {
268+
if (!experimentIds) {
269+
return [];
270+
}
271+
272+
const {hparams} = getExperimentsHparamsAndMetricsSpecs(state, {
273+
experimentIds,
274+
});
275+
276+
return hparams.map((spec) => ({
277+
type: ColumnHeaderType.HPARAM,
278+
name: spec.name,
279+
// According to the api spec when the displayName is empty, the name should
280+
// be displayed tensorboard/plugins/hparams/api.proto
281+
displayName: spec.displayName || spec.name,
282+
enabled: false,
283+
}));
284+
}
285+
);
286+
287+
export const getAllPotentialColumnsForCard = memoize((cardId: string) => {
288+
return createSelector(
289+
getColumnHeadersForCard(cardId),
290+
getPotentialHparamColumns,
291+
(staticColumnHeaders, potentialHparamColumns) => {
292+
return [...staticColumnHeaders, ...potentialHparamColumns];
293+
}
294+
);
295+
});
296+
257297
export const factories = {
258298
getRenderableRuns,
259299
getFilteredRenderableRuns,

tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515
import {RouteKind} from '../../../app_routing';
16+
import {
17+
buildSpecs,
18+
buildHparamSpec,
19+
buildMetricSpec,
20+
} from '../../../hparams/_redux/testing';
1621
import {
1722
buildAppRoutingState,
1823
buildStateFromAppRoutingState,
@@ -35,6 +40,7 @@ import {
3540
} from '../../testing';
3641
import {PluginType} from '../../types';
3742
import * as selectors from './common_selectors';
43+
import {ColumnHeaderType} from '../card_renderer/scalar_card_types';
3844

3945
describe('common selectors', () => {
4046
let runIds: Record<string, string[]>;
@@ -171,6 +177,18 @@ describe('common selectors', () => {
171177
},
172178
},
173179
} as any,
180+
hparams: {
181+
specs: buildSpecs('defaultExperimentId', {
182+
hparam: {
183+
specs: [buildHparamSpec({name: 'foo', displayName: 'Foo'})],
184+
defaultFilters: new Map(),
185+
},
186+
metric: {
187+
specs: [buildMetricSpec({displayName: 'Bar'})],
188+
defaultFilters: new Map(),
189+
},
190+
}),
191+
} as any,
174192
});
175193
});
176194

@@ -968,4 +986,43 @@ describe('common selectors', () => {
968986
expect(result).toEqual(new Set());
969987
});
970988
});
989+
990+
describe('getPotentialHparamColumns', () => {
991+
it('returns empty list when there are no experiments', () => {
992+
state.app_routing!.activeRoute!.routeKind = RouteKind.EXPERIMENTS;
993+
994+
expect(selectors.getPotentialHparamColumns(state)).toEqual([]);
995+
});
996+
997+
it('creates columns for each hparam', () => {
998+
expect(selectors.getPotentialHparamColumns(state)).toEqual([
999+
{
1000+
type: ColumnHeaderType.HPARAM,
1001+
name: 'foo',
1002+
displayName: 'Foo',
1003+
enabled: false,
1004+
},
1005+
]);
1006+
});
1007+
1008+
it('sets name as display name when a display name is not provided', () => {
1009+
state.hparams!.specs['defaultExperimentId'].hparam.specs.push(
1010+
buildHparamSpec({name: 'bar', displayName: ''})
1011+
);
1012+
expect(selectors.getPotentialHparamColumns(state)).toEqual([
1013+
{
1014+
type: ColumnHeaderType.HPARAM,
1015+
name: 'foo',
1016+
displayName: 'Foo',
1017+
enabled: false,
1018+
},
1019+
{
1020+
type: ColumnHeaderType.HPARAM,
1021+
name: 'bar',
1022+
displayName: 'bar',
1023+
enabled: false,
1024+
},
1025+
]);
1026+
});
1027+
});
9711028
});

tensorboard/webapp/testing/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ tf_ts_library(
7878
"//tensorboard/webapp/feature_flag/store:testing",
7979
"//tensorboard/webapp/feature_flag/store:types",
8080
"//tensorboard/webapp/hparams/_redux:testing",
81+
"//tensorboard/webapp/hparams/_redux:types",
8182
"//tensorboard/webapp/metrics:test_lib",
8283
"//tensorboard/webapp/metrics/store",
8384
"//tensorboard/webapp/notification_center/_redux:testing",

tensorboard/webapp/testing/utils.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ import {
5656
createState as createSettings,
5757
createSettingsState,
5858
} from '../settings/testing';
59+
import {HPARAMS_FEATURE_KEY} from '../hparams/_redux/types';
5960

6061
export function buildMockState(overrides: Partial<State> = {}): State {
6162
return {
@@ -81,7 +82,12 @@ export function buildMockState(overrides: Partial<State> = {}): State {
8182
buildAppRoutingState(overrides[APP_ROUTING_FEATURE_KEY])
8283
),
8384
...buildStateFromFeatureFlagsState(buildFeatureFlagState()),
84-
...buildStateFromHparamsState(buildHparamsState()),
85+
...buildStateFromHparamsState(
86+
buildHparamsState(
87+
overrides[HPARAMS_FEATURE_KEY]?.specs,
88+
overrides[HPARAMS_FEATURE_KEY]?.filters
89+
)
90+
),
8591
...buildStateFromNotificationState(
8692
buildNotificationState(overrides[NOTIFICATION_FEATURE_KEY] || {})
8793
),

tensorboard/webapp/widgets/data_table/data_table_component.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@ export class DataTableComponent implements OnDestroy {
9595
case ColumnHeaderType.STEP_AT_MAX:
9696
case ColumnHeaderType.STEP_AT_MIN:
9797
case ColumnHeaderType.MEAN:
98-
return intlNumberFormatter.formatShort(datum as number);
98+
case ColumnHeaderType.HPARAM:
99+
if (typeof datum === 'number') {
100+
return intlNumberFormatter.formatShort(datum as number);
101+
}
102+
return datum;
99103
case ColumnHeaderType.TIME:
100104
const time = new Date(datum!);
101105
return time.toISOString();

0 commit comments

Comments
 (0)