Skip to content

Commit afff55a

Browse files
authored
Ensure custom plots legend matches plot values (#4729)
1 parent 65204e3 commit afff55a

File tree

2 files changed

+75
-32
lines changed

2 files changed

+75
-32
lines changed

extension/src/plots/model/collect.test.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,34 @@ describe('collectCustomPlots', () => {
6666
})
6767
expect(data[0].values.slice(-1)[0].id).toStrictEqual('main')
6868
})
69+
70+
it('should create custom plot scales that match the collected values', () => {
71+
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots
72+
const data = collectCustomPlots({
73+
colorScale: {
74+
domain: ['main', 'exp-e7a67', 'test-branch', 'exp-83425', 'failed-exp'],
75+
range: ['#13adc7', '#f46837', '#48bb78', '#4299e1', '#f56565']
76+
},
77+
experiments: [
78+
...experimentsWithCommits,
79+
{
80+
branch: 'main',
81+
id: 'weird-exp',
82+
label: 'exp with no metrics or params'
83+
},
84+
{
85+
branch: 'main',
86+
error: 'failed to run',
87+
id: 'failed-exp',
88+
label: '123'
89+
}
90+
],
91+
height: DEFAULT_PLOT_HEIGHT,
92+
nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW,
93+
plotsOrderValues: customPlotsOrderFixture
94+
})
95+
expect(data).toStrictEqual(expectedOutput)
96+
})
6997
})
7098

7199
describe('collectData', () => {

extension/src/plots/model/collect.ts

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,63 @@ const getValues = (
8080
return values
8181
}
8282

83+
const removeSelectedExperiment = (
84+
orderedColorScale: ColorScale,
85+
hasValues: boolean,
86+
idx: number
87+
) => {
88+
const isSelectedExperiment = idx !== -1
89+
if (!isSelectedExperiment || hasValues) {
90+
return
91+
}
92+
93+
orderedColorScale.domain.splice(idx, 1)
94+
orderedColorScale.range.splice(idx, 1)
95+
}
96+
97+
const fillColorScale = (
98+
experiments: Experiment[],
99+
colorScale: ColorScale | undefined,
100+
valueIds: Set<string>
101+
) => {
102+
const orderedColorScale = {
103+
domain: [...(colorScale?.domain || [])],
104+
range: [...(colorScale?.range || [])]
105+
}
106+
107+
for (const experiment of experiments) {
108+
const { id } = experiment
109+
const idx = orderedColorScale.domain.indexOf(id)
110+
const isSelectedExperiment = idx !== -1
111+
const hasValues = valueIds.has(id)
112+
113+
if (!hasValues || isSelectedExperiment) {
114+
removeSelectedExperiment(orderedColorScale, hasValues, idx)
115+
continue
116+
}
117+
118+
orderedColorScale.domain.push(id)
119+
orderedColorScale.range.push('#4c78a8' as Color)
120+
}
121+
122+
return orderedColorScale
123+
}
124+
83125
const getCustomPlotData = (
84126
orderValue: CustomPlotsOrderValue,
85127
experiments: Experiment[],
86128
height: number,
87129
nbItemsPerRow: number,
88-
completeColorScale: ColorScale,
89-
renderLastIds: Set<string>
130+
colorScale: ColorScale | undefined
90131
): CustomPlotData => {
91132
const { metric, param } = orderValue
92133
const metricPath = getFullValuePath(ColumnType.METRICS, metric)
93134
const paramPath = getFullValuePath(ColumnType.PARAMS, param)
94135

136+
const renderLastIds = new Set(colorScale?.domain)
95137
const values = getValues(experiments, metricPath, paramPath, renderLastIds)
138+
const valueIds = new Set(values.map(({ id }) => id))
139+
const completeColorScale = fillColorScale(experiments, colorScale, valueIds)
96140

97141
const [{ param: paramVal, metric: metricVal }] = values
98142
const yTitle = truncateVerticalTitle(metric, nbItemsPerRow, height) as string
@@ -115,26 +159,6 @@ const getCustomPlotData = (
115159
} as CustomPlotData
116160
}
117161

118-
const fillColorScale = (
119-
colorScale: ColorScale | undefined,
120-
experiments: Experiment[]
121-
) => {
122-
const completeColorScale = {
123-
domain: [...(colorScale?.domain || [])],
124-
range: [...(colorScale?.range || [])]
125-
}
126-
127-
for (const experiment of experiments) {
128-
const { id } = experiment
129-
if (completeColorScale.domain.includes(id)) {
130-
continue
131-
}
132-
completeColorScale.domain.push(id)
133-
completeColorScale.range.push('#4c78a8' as Color)
134-
}
135-
return completeColorScale
136-
}
137-
138162
export const collectCustomPlots = ({
139163
colorScale,
140164
plotsOrderValues,
@@ -150,18 +174,9 @@ export const collectCustomPlots = ({
150174
}): CustomPlotData[] => {
151175
const plots = []
152176

153-
const completeColorScale = fillColorScale(colorScale, experiments)
154-
155177
for (const value of plotsOrderValues) {
156178
plots.push(
157-
getCustomPlotData(
158-
value,
159-
experiments,
160-
height,
161-
nbItemsPerRow,
162-
completeColorScale,
163-
new Set(colorScale?.domain)
164-
)
179+
getCustomPlotData(value, experiments, height, nbItemsPerRow, colorScale)
165180
)
166181
}
167182

0 commit comments

Comments
 (0)