Skip to content

Commit 8ce1a4c

Browse files
authored
Fix size of flexible confusion matrix (#2531)
* fix size of multi source confusion matrix * refactor transform revision data
1 parent 6c5e2aa commit 8ce1a4c

File tree

3 files changed

+109
-59
lines changed

3 files changed

+109
-59
lines changed

extension/src/experiments/model/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import { sum } from '../../util/math'
3939

4040
export type StarredExperiments = Record<string, boolean | undefined>
4141

42-
type SelectedExperimentWithColor = Experiment & {
42+
export type SelectedExperimentWithColor = Experiment & {
4343
displayColor: Color
4444
selected: true
4545
}

extension/src/plots/model/collect.ts

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import {
4040
MultiSourceEncoding,
4141
unmergeConcatenatedFields
4242
} from '../multiSource/collect'
43+
import { StrokeDashEncoding } from '../multiSource/constants'
4344

4445
type CheckpointPlotAccumulator = {
4546
iterations: Record<string, number>
@@ -502,49 +503,95 @@ export const collectTemplates = (data: PlotsOutput): TemplateAccumulator => {
502503
}
503504

504505
const updateDatapoints = (
505-
datapoints: unknown[],
506+
path: string,
507+
revisionData: RevisionData,
508+
selectedRevisions: string[],
506509
key: string,
507510
fields: string[]
508511
): unknown[] =>
509-
datapoints.map(data => {
510-
const obj = data as Record<string, unknown>
511-
return {
512-
...obj,
513-
[key]: mergeFields(fields.map(field => obj[field] as string))
512+
selectedRevisions
513+
.flatMap(revision =>
514+
revisionData?.[revision]?.[path].map(data => {
515+
const obj = data as Record<string, unknown>
516+
return {
517+
...obj,
518+
[key]: mergeFields(fields.map(field => obj[field] as string))
519+
}
520+
})
521+
)
522+
.filter(Boolean)
523+
524+
const updateRevisions = (
525+
selectedRevisions: string[],
526+
domain: string[]
527+
): string[] => {
528+
const revisions: string[] = []
529+
for (const revision of selectedRevisions) {
530+
for (const entry of domain) {
531+
revisions.push(mergeFields([revision, entry]))
514532
}
515-
})
533+
}
534+
return revisions
535+
}
516536

517-
const stringifyDatapoints = (
518-
datapoints: unknown[],
519-
field: string | undefined,
520-
isMultiView: boolean
521-
): string => {
522-
if (!field || (!isMultiView && !isConcatenatedField(field))) {
523-
return JSON.stringify(datapoints)
537+
const transformRevisionData = (
538+
path: string,
539+
selectedRevisions: string[],
540+
revisionData: RevisionData,
541+
isMultiView: boolean,
542+
multiSourceEncodingUpdate: { strokeDash: StrokeDashEncoding }
543+
): { revisions: string[]; datapoints: unknown[] } => {
544+
const field = multiSourceEncodingUpdate.strokeDash?.field
545+
const isMultiSource = !!field
546+
547+
const transformNeeded =
548+
isMultiSource && (isMultiView || isConcatenatedField(field))
549+
550+
if (!transformNeeded) {
551+
return {
552+
datapoints: selectedRevisions
553+
.flatMap(revision => revisionData?.[revision]?.[path])
554+
.filter(Boolean),
555+
revisions: selectedRevisions
556+
}
524557
}
525558

526559
const fields = unmergeConcatenatedFields(field)
527-
528560
if (isMultiView) {
529561
fields.unshift('rev')
530-
return JSON.stringify(updateDatapoints(datapoints, 'rev', fields))
562+
return {
563+
datapoints: updateDatapoints(
564+
path,
565+
revisionData,
566+
selectedRevisions,
567+
'rev',
568+
fields
569+
),
570+
revisions: updateRevisions(
571+
selectedRevisions,
572+
multiSourceEncodingUpdate.strokeDash.scale.domain
573+
)
574+
}
531575
}
532576

533-
return JSON.stringify(updateDatapoints(datapoints, field, fields))
577+
return {
578+
datapoints: updateDatapoints(
579+
path,
580+
revisionData,
581+
selectedRevisions,
582+
field,
583+
fields
584+
),
585+
revisions: selectedRevisions
586+
}
534587
}
535588

536589
const fillTemplate = (
537590
template: string,
538-
datapoints: unknown[],
539-
field?: string
591+
datapoints: unknown[]
540592
): TopLevelSpec => {
541-
const isMultiView = isMultiViewPlot(JSON.parse(template))
542-
543593
return JSON.parse(
544-
template.replace(
545-
'"<DVC_METRIC_DATA>"',
546-
stringifyDatapoints(datapoints, field, isMultiView)
547-
)
594+
template.replace('"<DVC_METRIC_DATA>"', JSON.stringify(datapoints))
548595
) as TopLevelSpec
549596
}
550597

@@ -562,30 +609,26 @@ const collectTemplateGroup = (
562609
const template = templates[path]
563610

564611
if (template) {
565-
const datapoints = selectedRevisions
566-
.flatMap(revision => revisionData?.[revision]?.[path])
567-
.filter(Boolean)
568-
612+
const isMultiView = isMultiViewPlot(JSON.parse(template))
569613
const multiSourceEncodingUpdate = multiSourceEncoding[path] || {}
570-
571-
const content = extendVegaSpec(
572-
fillTemplate(
573-
template,
574-
datapoints,
575-
multiSourceEncodingUpdate.strokeDash?.field
576-
),
577-
size,
578-
{
579-
...multiSourceEncodingUpdate,
580-
color: revisionColors
581-
}
614+
const { datapoints, revisions } = transformRevisionData(
615+
path,
616+
selectedRevisions,
617+
revisionData,
618+
isMultiView,
619+
multiSourceEncodingUpdate
582620
)
583621

622+
const content = extendVegaSpec(fillTemplate(template, datapoints), size, {
623+
...multiSourceEncodingUpdate,
624+
color: revisionColors
625+
})
626+
584627
acc.push({
585628
content,
586629
id: path,
587630
multiView: isMultiViewPlot(content),
588-
revisions: selectedRevisions,
631+
revisions,
589632
type: PlotsType.VEGA
590633
})
591634
}

extension/src/test/suite/plots/index.test.ts

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import { MessageFromWebviewType } from '../../../webview/contract'
3737
import { reorderObjectList } from '../../../util/array'
3838
import * as Telemetry from '../../../telemetry'
3939
import { EventName } from '../../../telemetry/constants'
40+
import { SelectedExperimentWithColor } from '../../../experiments/model'
4041

4142
suite('Plots Test Suite', () => {
4243
const disposable = Disposable.fn()
@@ -761,10 +762,13 @@ suite('Plots Test Suite', () => {
761762
}).timeout(WEBVIEW_TEST_TIMEOUT)
762763

763764
it('should send the correct data to the webview for flexible plots', async () => {
764-
const { plots, messageSpy, mockPlotsDiff } = await buildPlots(
765-
disposable,
766-
multiSourcePlotsDiffFixture
767-
)
765+
const { plots, messageSpy, mockPlotsDiff, experiments } =
766+
await buildPlots(disposable, multiSourcePlotsDiffFixture)
767+
768+
stub(experiments, 'getSelectedRevisions').returns([
769+
{ label: 'workspace' },
770+
{ label: 'main' }
771+
] as SelectedExperimentWithColor[])
768772

769773
const webview = await plots.showWebview()
770774
await webview.isReady()
@@ -798,17 +802,6 @@ suite('Plots Test Suite', () => {
798802
multiViewSection.entries.map(({ id }: { id: string }) => id)
799803
).to.deep.equal(['dvc.yaml::Confusion-Matrix'])
800804

801-
const [confusionMatrix] = multiViewSection.entries
802-
803-
const confusionMatrixDatapoints =
804-
(
805-
confusionMatrix.content.data as {
806-
values: { rev: string }[]
807-
}
808-
)?.values || []
809-
810-
expect(confusionMatrixDatapoints.length).to.be.greaterThan(0)
811-
812805
const expectedRevisions = [
813806
`main::${join('evaluation', 'test', 'plots', 'confusion_matrix.json')}`,
814807
`workspace::${join(
@@ -829,7 +822,21 @@ suite('Plots Test Suite', () => {
829822
'plots',
830823
'confusion_matrix.json'
831824
)}`
832-
]
825+
].sort()
826+
827+
const [confusionMatrix] = multiViewSection.entries
828+
829+
const confusionMatrixDatapoints =
830+
(
831+
confusionMatrix.content.data as {
832+
values: { rev: string }[]
833+
}
834+
)?.values || []
835+
836+
expect(confusionMatrixDatapoints.length).to.be.greaterThan(0)
837+
838+
expect(confusionMatrix.revisions?.length).to.equal(4)
839+
expect(confusionMatrix.revisions?.sort()).to.deep.equal(expectedRevisions)
833840

834841
for (const entry of confusionMatrixDatapoints) {
835842
expect(expectedRevisions).to.include(entry.rev)

0 commit comments

Comments
 (0)