Skip to content

Commit 304046d

Browse files
authored
Adds common selectors for shared hparams (#6732)
## Motivation for features / changes Displaying shared hparam columns in the runs and scalar tables requires new selectors. ## Technical description of changes Adds two new common selectors that will be used by both runs and scalar tables: - getSelectableColumns - getGroupedColumns Also updates getDashboardDisplayedHparamColumns to only return relevant hparams (i.e. those with specs defined by selected experiments) ## Detailed steps to verify changes work correctly (as executed by you) Unit tests pass
1 parent d6ad97e commit 304046d

File tree

19 files changed

+708
-54
lines changed

19 files changed

+708
-54
lines changed

tensorboard/webapp/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ tf_ng_web_test_suite(
301301
"//tensorboard/webapp/widgets/content_wrapping_input:content_wrapping_input_tests",
302302
"//tensorboard/webapp/widgets/custom_modal:custom_modal_test",
303303
"//tensorboard/webapp/widgets/data_table:data_table_test",
304+
"//tensorboard/webapp/widgets/data_table:utils_test",
304305
"//tensorboard/webapp/widgets/dropdown:dropdown_tests",
305306
"//tensorboard/webapp/widgets/experiment_alias:experiment_alias_test",
306307
"//tensorboard/webapp/widgets/filter_input:filter_input_test",

tensorboard/webapp/hparams/_redux/hparams_selectors.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,14 @@ export const getDashboardDefaultHparamFilters = createSelector(
4747
);
4848

4949
export const getDashboardDisplayedHparamColumns = createSelector(
50+
getDashboardHparamsAndMetricsSpecs,
5051
getHparamsState,
51-
(state) => state.dashboardDisplayedHparamColumns
52+
({hparams}, state) => {
53+
const hparamSet = new Set(hparams.map((hparam) => hparam.name));
54+
return state.dashboardDisplayedHparamColumns.filter((column) =>
55+
hparamSet.has(column.name)
56+
);
57+
}
5258
);
5359

5460
export const getDashboardHparamFilterMap = createSelector(

tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
import {ColumnHeaderType} from '../../widgets/data_table/types';
1717
import {DomainType} from '../types';
18+
import {State} from './types';
1819
import * as selectors from './hparams_selectors';
1920
import {
2021
buildHparamSpec,
@@ -114,30 +115,63 @@ describe('hparams/_redux/hparams_selectors_test', () => {
114115
});
115116

116117
describe('#getDashboardDisplayedHparamColumns', () => {
117-
it('returns dashboard displayed hparam columns', () => {
118-
const fakeColumns = [
119-
{
120-
type: ColumnHeaderType.HPARAM,
121-
name: 'conv_layers',
122-
displayName: 'Conv Layers',
123-
enabled: true,
124-
},
125-
{
126-
type: ColumnHeaderType.HPARAM,
127-
name: 'conv_kernel_size',
128-
displayName: 'Conv Kernel Size',
129-
enabled: true,
130-
},
131-
];
118+
it('returns no columns if no hparam specs', () => {
132119
const state = buildStateFromHparamsState(
133120
buildHparamsState({
134-
dashboardDisplayedHparamColumns: fakeColumns,
121+
dashboardSpecs: {
122+
hparams: [],
123+
},
124+
dashboardDisplayedHparamColumns: [
125+
{
126+
type: ColumnHeaderType.HPARAM,
127+
name: 'conv_layers',
128+
displayName: 'Conv Layers',
129+
enabled: true,
130+
},
131+
{
132+
type: ColumnHeaderType.HPARAM,
133+
name: 'conv_kernel_size',
134+
displayName: 'Conv Kernel Size',
135+
enabled: true,
136+
},
137+
],
135138
})
136139
);
137140

138-
expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual(
139-
fakeColumns
141+
expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([]);
142+
});
143+
144+
it('returns only hparam columns that have specs', () => {
145+
const state = buildStateFromHparamsState(
146+
buildHparamsState({
147+
dashboardSpecs: {
148+
hparams: [buildHparamSpec({name: 'conv_layers'})],
149+
},
150+
dashboardDisplayedHparamColumns: [
151+
{
152+
type: ColumnHeaderType.HPARAM,
153+
name: 'conv_layers',
154+
displayName: 'Conv Layers',
155+
enabled: true,
156+
},
157+
{
158+
type: ColumnHeaderType.HPARAM,
159+
name: 'conv_kernel_size',
160+
displayName: 'Conv Kernel Size',
161+
enabled: true,
162+
},
163+
],
164+
})
140165
);
166+
167+
expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([
168+
{
169+
type: ColumnHeaderType.HPARAM,
170+
name: 'conv_layers',
171+
displayName: 'Conv Layers',
172+
enabled: true,
173+
},
174+
]);
141175
});
142176
});
143177
});

tensorboard/webapp/metrics/store/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ tf_ts_library(
1818
"//tensorboard/webapp/app_routing:types",
1919
"//tensorboard/webapp/app_routing/actions",
2020
"//tensorboard/webapp/core/actions",
21+
"//tensorboard/webapp/hparams/_redux:hparams_selectors",
2122
"//tensorboard/webapp/metrics:types",
2223
"//tensorboard/webapp/metrics:utils",
2324
"//tensorboard/webapp/metrics/actions",
@@ -33,6 +34,7 @@ tf_ts_library(
3334
"//tensorboard/webapp/util:types",
3435
"//tensorboard/webapp/widgets/card_fob:types",
3536
"//tensorboard/webapp/widgets/data_table:types",
37+
"//tensorboard/webapp/widgets/data_table:utils",
3638
"//tensorboard/webapp/widgets/line_chart_v2/lib:public_types",
3739
"@npm//@ngrx/store",
3840
],
@@ -97,14 +99,18 @@ tf_ts_library(
9799
"//tensorboard/webapp/app_routing:types",
98100
"//tensorboard/webapp/app_routing/actions",
99101
"//tensorboard/webapp/core/actions",
102+
"//tensorboard/webapp/hparams:testing",
103+
"//tensorboard/webapp/hparams/_redux:types",
100104
"//tensorboard/webapp/metrics:test_lib",
101105
"//tensorboard/webapp/metrics:types",
102106
"//tensorboard/webapp/metrics/actions",
103107
"//tensorboard/webapp/metrics/data_source",
104108
"//tensorboard/webapp/persistent_settings",
105109
"//tensorboard/webapp/routes:testing",
110+
"//tensorboard/webapp/testing:utils",
106111
"//tensorboard/webapp/types",
107112
"//tensorboard/webapp/util:dom",
113+
"//tensorboard/webapp/util:types",
108114
"//tensorboard/webapp/widgets/card_fob:types",
109115
"//tensorboard/webapp/widgets/data_table:types",
110116
"@npm//@types/jasmine",

tensorboard/webapp/metrics/store/metrics_selectors.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ import {
5050
import {ColumnHeader, DataTableMode} from '../../widgets/data_table/types';
5151
import {Extent} from '../../widgets/line_chart_v2/lib/public_types';
5252
import {memoize} from '../../util/memoize';
53+
import {getDashboardDisplayedHparamColumns} from '../../hparams/_redux/hparams_selectors';
54+
import {DataTableUtils} from '../../widgets/data_table/utils';
5355

5456
const selectMetricsState =
5557
createFeatureSelector<MetricsState>(METRICS_FEATURE_KEY);
@@ -661,3 +663,12 @@ export const getColumnHeadersForCard = memoize((cardId: string) => {
661663
}
662664
);
663665
});
666+
667+
export const getGroupedHeadersForCard = memoize((cardId: string) =>
668+
createSelector(
669+
getColumnHeadersForCard(cardId),
670+
getDashboardDisplayedHparamColumns,
671+
(standardColumns, hparamColumns) =>
672+
DataTableUtils.groupColumns([...standardColumns, ...hparamColumns])
673+
)
674+
);

tensorboard/webapp/metrics/store/metrics_selectors_test.ts

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ import {
3333
} from '../../widgets/data_table/types';
3434
import * as selectors from './metrics_selectors';
3535
import {CardFeatureOverride, MetricsState} from './metrics_types';
36+
import {buildMockState} from '../../testing/utils';
37+
import {
38+
buildHparamSpec,
39+
buildHparamsState,
40+
buildStateFromHparamsState,
41+
} from '../../hparams/testing';
42+
import {DeepPartial} from '../../util/types';
43+
import {HparamsState} from '../../hparams/_redux/types';
3644

3745
describe('metrics selectors', () => {
3846
beforeEach(() => {
@@ -1745,4 +1753,192 @@ describe('metrics selectors', () => {
17451753
).toEqual(rangeSelectionHeaders);
17461754
});
17471755
});
1756+
1757+
describe('getGroupedHeadersForCard', () => {
1758+
let singleSelectionHeaders: ColumnHeader[];
1759+
let rangeSelectionHeaders: ColumnHeader[];
1760+
let hparamsState: DeepPartial<HparamsState>;
1761+
1762+
beforeEach(() => {
1763+
singleSelectionHeaders = [
1764+
{
1765+
type: ColumnHeaderType.COLOR,
1766+
name: 'color',
1767+
displayName: 'Color',
1768+
enabled: true,
1769+
},
1770+
{
1771+
type: ColumnHeaderType.RUN,
1772+
name: 'run',
1773+
displayName: 'My Run name',
1774+
enabled: false,
1775+
},
1776+
];
1777+
rangeSelectionHeaders = [
1778+
{
1779+
type: ColumnHeaderType.MEAN,
1780+
name: 'mean',
1781+
displayName: 'Mean',
1782+
enabled: true,
1783+
},
1784+
{
1785+
type: ColumnHeaderType.RUN,
1786+
name: 'run',
1787+
displayName: 'My Run name',
1788+
enabled: false,
1789+
},
1790+
];
1791+
hparamsState = {
1792+
dashboardSpecs: {
1793+
hparams: [
1794+
buildHparamSpec({name: 'conv_layers'}),
1795+
buildHparamSpec({name: 'conv_kernel_size'}),
1796+
],
1797+
},
1798+
dashboardDisplayedHparamColumns: [
1799+
{
1800+
type: ColumnHeaderType.HPARAM,
1801+
name: 'conv_layers',
1802+
displayName: 'Conv Layers',
1803+
enabled: true,
1804+
},
1805+
{
1806+
type: ColumnHeaderType.HPARAM,
1807+
name: 'conv_kernel_size',
1808+
displayName: 'Conv Kernel Size',
1809+
enabled: true,
1810+
},
1811+
],
1812+
};
1813+
});
1814+
1815+
it('returns grouped single selection headers when card range selection is disabled', () => {
1816+
const state = buildMockState({
1817+
...appStateFromMetricsState(
1818+
buildMetricsState({
1819+
singleSelectionHeaders,
1820+
rangeSelectionHeaders,
1821+
cardStateMap: {
1822+
card1: {
1823+
rangeSelectionOverride:
1824+
CardFeatureOverride.OVERRIDE_AS_DISABLED,
1825+
},
1826+
},
1827+
})
1828+
),
1829+
...buildStateFromHparamsState(buildHparamsState(hparamsState)),
1830+
});
1831+
1832+
expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([
1833+
{
1834+
type: ColumnHeaderType.RUN,
1835+
name: 'run',
1836+
displayName: 'My Run name',
1837+
enabled: false,
1838+
},
1839+
{
1840+
type: ColumnHeaderType.HPARAM,
1841+
name: 'conv_layers',
1842+
displayName: 'Conv Layers',
1843+
enabled: true,
1844+
},
1845+
{
1846+
type: ColumnHeaderType.HPARAM,
1847+
name: 'conv_kernel_size',
1848+
displayName: 'Conv Kernel Size',
1849+
enabled: true,
1850+
},
1851+
{
1852+
type: ColumnHeaderType.COLOR,
1853+
name: 'color',
1854+
displayName: 'Color',
1855+
enabled: true,
1856+
},
1857+
]);
1858+
});
1859+
1860+
it('returns grouped range selection headers when card range selection is enabled', () => {
1861+
const state = buildMockState({
1862+
...appStateFromMetricsState(
1863+
buildMetricsState({
1864+
singleSelectionHeaders,
1865+
rangeSelectionHeaders,
1866+
cardStateMap: {
1867+
card1: {
1868+
rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_ENABLED,
1869+
},
1870+
},
1871+
})
1872+
),
1873+
...buildStateFromHparamsState(buildHparamsState(hparamsState)),
1874+
});
1875+
1876+
expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([
1877+
{
1878+
type: ColumnHeaderType.RUN,
1879+
name: 'run',
1880+
displayName: 'My Run name',
1881+
enabled: false,
1882+
},
1883+
{
1884+
type: ColumnHeaderType.HPARAM,
1885+
name: 'conv_layers',
1886+
displayName: 'Conv Layers',
1887+
enabled: true,
1888+
},
1889+
{
1890+
type: ColumnHeaderType.HPARAM,
1891+
name: 'conv_kernel_size',
1892+
displayName: 'Conv Kernel Size',
1893+
enabled: true,
1894+
},
1895+
{
1896+
type: ColumnHeaderType.MEAN,
1897+
name: 'mean',
1898+
displayName: 'Mean',
1899+
enabled: true,
1900+
},
1901+
]);
1902+
});
1903+
1904+
it('returns grouped range selection headers when global range selection is enabled', () => {
1905+
const state = buildMockState({
1906+
...appStateFromMetricsState(
1907+
buildMetricsState({
1908+
singleSelectionHeaders,
1909+
rangeSelectionHeaders,
1910+
rangeSelectionEnabled: true,
1911+
})
1912+
),
1913+
...buildStateFromHparamsState(buildHparamsState(hparamsState)),
1914+
});
1915+
1916+
expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([
1917+
{
1918+
type: ColumnHeaderType.RUN,
1919+
name: 'run',
1920+
displayName: 'My Run name',
1921+
enabled: false,
1922+
},
1923+
{
1924+
type: ColumnHeaderType.HPARAM,
1925+
name: 'conv_layers',
1926+
displayName: 'Conv Layers',
1927+
enabled: true,
1928+
},
1929+
{
1930+
type: ColumnHeaderType.HPARAM,
1931+
name: 'conv_kernel_size',
1932+
displayName: 'Conv Kernel Size',
1933+
enabled: true,
1934+
},
1935+
{
1936+
type: ColumnHeaderType.MEAN,
1937+
name: 'mean',
1938+
displayName: 'Mean',
1939+
enabled: true,
1940+
},
1941+
]);
1942+
});
1943+
});
17481944
});

tensorboard/webapp/metrics/testing.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,12 @@ import {
2828
TimeSeriesRequest,
2929
TimeSeriesResponse,
3030
} from './data_source';
31+
import * as selectors from './store/metrics_selectors';
3132
import {
3233
MetricsState,
3334
METRICS_FEATURE_KEY,
3435
TagMetadata,
3536
TimeSeriesData,
36-
} from './store';
37-
import * as selectors from './store/metrics_selectors';
38-
import {
3937
CardStepIndexMetaData,
4038
MetricsSettings,
4139
RunToSeries,

0 commit comments

Comments
 (0)