Skip to content

Commit 6c5e2aa

Browse files
authored
Render flexible confusion matrices as expected (#2523)
* concatenate data required for flexible confusion matrix into rev field * refactor fill template * remove unnecessary stroke dash entries from confusion matrix * refactor suppression of encoding elements for confusion matrices * add unit test for get children * include integration test * refactor fill template * update HEAD revision to main
1 parent f69837c commit 6c5e2aa

File tree

7 files changed

+337534
-31
lines changed

7 files changed

+337534
-31
lines changed

extension/src/plots/model/collect.ts

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -501,30 +501,49 @@ export const collectTemplates = (data: PlotsOutput): TemplateAccumulator => {
501501
return acc
502502
}
503503

504+
const updateDatapoints = (
505+
datapoints: unknown[],
506+
key: string,
507+
fields: string[]
508+
): unknown[] =>
509+
datapoints.map(data => {
510+
const obj = data as Record<string, unknown>
511+
return {
512+
...obj,
513+
[key]: mergeFields(fields.map(field => obj[field] as string))
514+
}
515+
})
516+
517+
const stringifyDatapoints = (
518+
datapoints: unknown[],
519+
field: string | undefined,
520+
isMultiView: boolean
521+
): string => {
522+
if (!field || (!isMultiView && !isConcatenatedField(field))) {
523+
return JSON.stringify(datapoints)
524+
}
525+
526+
const fields = unmergeConcatenatedFields(field)
527+
528+
if (isMultiView) {
529+
fields.unshift('rev')
530+
return JSON.stringify(updateDatapoints(datapoints, 'rev', fields))
531+
}
532+
533+
return JSON.stringify(updateDatapoints(datapoints, field, fields))
534+
}
535+
504536
const fillTemplate = (
505537
template: string,
506538
datapoints: unknown[],
507539
field?: string
508-
) => {
509-
if (!field || !isConcatenatedField(field)) {
510-
return JSON.parse(
511-
template.replace('"<DVC_METRIC_DATA>"', JSON.stringify(datapoints))
512-
) as TopLevelSpec
513-
}
540+
): TopLevelSpec => {
541+
const isMultiView = isMultiViewPlot(JSON.parse(template))
514542

515-
const fields = unmergeConcatenatedFields(field)
516543
return JSON.parse(
517544
template.replace(
518545
'"<DVC_METRIC_DATA>"',
519-
JSON.stringify(
520-
datapoints.map(data => {
521-
const obj = data as Record<string, unknown>
522-
return {
523-
...obj,
524-
[field]: mergeFields(fields.map(field => obj[field] as string))
525-
}
526-
})
527-
)
546+
stringifyDatapoints(datapoints, field, isMultiView)
528547
)
529548
) as TopLevelSpec
530549
}

extension/src/plots/paths/model.test.ts

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,96 @@ describe('PathsModel', () => {
165165

166166
expect(model.getComparisonPaths()).toStrictEqual(newOrder)
167167
})
168+
169+
it('should return the expected children from the test fixture', () => {
170+
const model = new PathsModel(mockDvcRoot, buildMockMemento())
171+
model.transformAndSet(plotsDiffFixture)
172+
173+
const rootChildren = model.getChildren(undefined, {
174+
'predictions.json': {
175+
strokeDash: { field: '', scale: { domain: [], range: [] } }
176+
}
177+
})
178+
expect(rootChildren).toStrictEqual([
179+
{
180+
descendantStatuses: [2, 2, 2],
181+
hasChildren: true,
182+
label: 'plots',
183+
parentPath: undefined,
184+
path: 'plots',
185+
status: 2
186+
},
187+
{
188+
descendantStatuses: [2, 2],
189+
hasChildren: true,
190+
label: 'logs',
191+
parentPath: undefined,
192+
path: 'logs',
193+
status: 2
194+
},
195+
{
196+
descendantStatuses: [],
197+
hasChildren: false,
198+
label: 'predictions.json',
199+
parentPath: undefined,
200+
path: 'predictions.json',
201+
status: 2,
202+
type: new Set([PathType.TEMPLATE_MULTI])
203+
}
204+
])
205+
206+
const directoryChildren = model.getChildren('logs')
207+
expect(directoryChildren).toStrictEqual([
208+
{
209+
descendantStatuses: [],
210+
hasChildren: false,
211+
label: 'loss.tsv',
212+
parentPath: 'logs',
213+
path: logsLoss,
214+
status: 2,
215+
type: new Set([PathType.TEMPLATE_SINGLE])
216+
},
217+
{
218+
descendantStatuses: [],
219+
hasChildren: false,
220+
label: 'acc.tsv',
221+
parentPath: 'logs',
222+
path: logsAcc,
223+
status: 2,
224+
type: new Set([PathType.TEMPLATE_SINGLE])
225+
}
226+
])
227+
228+
const plotsWithEncoding = model.getChildren('logs', {
229+
[logsAcc]: {
230+
strokeDash: { field: '', scale: { domain: [], range: [] } }
231+
},
232+
[logsLoss]: {
233+
strokeDash: { field: '', scale: { domain: [], range: [] } }
234+
}
235+
})
236+
expect(plotsWithEncoding).toStrictEqual([
237+
{
238+
descendantStatuses: [],
239+
hasChildren: true,
240+
label: 'loss.tsv',
241+
parentPath: 'logs',
242+
path: logsLoss,
243+
status: 2,
244+
type: new Set([PathType.TEMPLATE_SINGLE])
245+
},
246+
{
247+
descendantStatuses: [],
248+
hasChildren: true,
249+
label: 'acc.tsv',
250+
parentPath: 'logs',
251+
path: logsAcc,
252+
status: 2,
253+
type: new Set([PathType.TEMPLATE_SINGLE])
254+
}
255+
])
256+
257+
const noChildren = model.getChildren(logsLoss)
258+
expect(noChildren).toStrictEqual([])
259+
})
168260
})

extension/src/plots/paths/model.ts

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,13 @@ export class PathsModel extends PathSelectionModel<PlotPath> {
5555
path: string | undefined,
5656
multiSourceEncoding: MultiSourceEncoding = {}
5757
) {
58-
return this.filterChildren(path).map(element => {
59-
const hasChildren =
60-
element.hasChildren === false
61-
? !!multiSourceEncoding[element.path]
62-
: element.hasChildren
63-
64-
return {
65-
...element,
66-
descendantStatuses: this.getTerminalNodeStatuses(element.path),
67-
hasChildren,
68-
label: element.label,
69-
status: this.status[element.path]
70-
}
71-
})
58+
return this.filterChildren(path).map(element => ({
59+
...element,
60+
descendantStatuses: this.getTerminalNodeStatuses(element.path),
61+
hasChildren: this.getHasChildren(element, multiSourceEncoding),
62+
label: element.label,
63+
status: this.status[element.path]
64+
}))
7265
}
7366

7467
public getTemplateOrder(): TemplateOrder {
@@ -116,4 +109,20 @@ export class PathsModel extends PathSelectionModel<PlotPath> {
116109
return element.parentPath === path
117110
})
118111
}
112+
113+
private getHasChildren(
114+
element: PlotPath,
115+
multiSourceEncoding: MultiSourceEncoding
116+
) {
117+
const hasEncodingChildren =
118+
!element.hasChildren &&
119+
!element.type?.has(PathType.TEMPLATE_MULTI) &&
120+
!!multiSourceEncoding[element.path]
121+
122+
if (hasEncodingChildren) {
123+
return true
124+
}
125+
126+
return element.hasChildren
127+
}
119128
}

extension/src/test/fixtures/plotsDiff/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,10 @@ export const getOutput = (
448448

449449
export const getMinimalOutput = (): PlotsOutput => ({ ...basicVega })
450450

451+
export const getMultiSourceOutput = (): PlotsOutput => ({
452+
...require('./multiSource').default
453+
})
454+
451455
const expectedRevisions = ['workspace', 'main', '4fb124a', '42b8736', '1ba7bcd']
452456

453457
const extendedSpecs = (plotsOutput: TemplatePlots): TemplatePlotSection[] => {

0 commit comments

Comments
 (0)