Skip to content

Commit c83e763

Browse files
authored
Patch plots for branches containing path separators (#1949)
* use short sha to fetch HEAD plots data (workaround branch names containing path separators) * refactor
1 parent 4d78b9e commit c83e763

File tree

13 files changed

+173
-68
lines changed

13 files changed

+173
-68
lines changed

extension/src/plots/data/index.ts

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
4646
return
4747
}
4848

49-
const args = sameContents(revs, ['workspace']) ? [] : revs
50-
49+
const args = this.getArgs(revs)
5150
const data = await this.internalCommands.executeCommand<PlotsOutput>(
5251
AvailableCommands.PLOTS_DIFF,
5352
this.dvcRoot,
@@ -58,7 +57,7 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
5857

5958
this.compareFiles(files)
6059

61-
return this.notifyChanged({ data, revs })
60+
return this.notifyChanged({ data, revs: args })
6261
}
6362

6463
public managedUpdate() {
@@ -72,4 +71,15 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
7271
public setModel(model: PlotsModel) {
7372
this.model = model
7473
}
74+
75+
private getArgs(revs: string[]) {
76+
if (
77+
this.model &&
78+
(sameContents(revs, ['workspace']) || sameContents(revs, []))
79+
) {
80+
return this.model.getDefaultRevs()
81+
}
82+
83+
return revs
84+
}
7585
}

extension/src/plots/model/collect.test.ts

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
uniqueValues
2323
} from '../../util/array'
2424
import { TemplatePlot } from '../webview/contract'
25+
import { getCLIBranchId } from '../../test/fixtures/plotsDiff/util'
2526

2627
const logsLossPath = join('logs', 'loss.tsv')
2728

@@ -226,7 +227,17 @@ describe('collectMetricOrder', () => {
226227

227228
describe('collectData', () => {
228229
it('should return the expected output from the test fixture', () => {
229-
const { revisionData, comparisonData } = collectData(plotsDiffFixture)
230+
const mapping = {
231+
'1ba7bcd': '1ba7bcd',
232+
'42b8736': '42b8736',
233+
'4fb124a': '4fb124a',
234+
'53c3851': 'main',
235+
workspace: 'workspace'
236+
}
237+
const { revisionData, comparisonData } = collectData(
238+
plotsDiffFixture,
239+
mapping
240+
)
230241
const revisions = ['workspace', 'main', '42b8736', '1ba7bcd', '4fb124a']
231242

232243
const values =
@@ -237,7 +248,7 @@ describe('collectData', () => {
237248
expect(isEmpty(values)).toBeFalsy()
238249

239250
for (const revision of revisions) {
240-
const expectedValues = values[revision].map(value => ({
251+
const expectedValues = values[getCLIBranchId(revision)].map(value => ({
241252
...value,
242253
rev: revision
243254
}))
@@ -287,7 +298,13 @@ describe('collectTemplates', () => {
287298
})
288299

289300
describe('collectWorkspaceRaceConditionData', () => {
290-
const { comparisonData, revisionData } = collectData(plotsDiffFixture)
301+
const { comparisonData, revisionData } = collectData(plotsDiffFixture, {
302+
'1ba7bcd': '1ba7bcd',
303+
'42b8736': '42b8736',
304+
'4fb124a': '4fb124a',
305+
'53c3851': 'main',
306+
workspace: 'workspace'
307+
})
291308

292309
it('should return no overwrite data if there is no selected checkpoint experiment running in the workspace', () => {
293310
const { overwriteComparisonData, overwriteRevisionData } =

extension/src/plots/model/collect.ts

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -312,31 +312,39 @@ export const collectMetricOrder = (
312312
type RevisionPathData = { [path: string]: Record<string, unknown>[] }
313313

314314
export type RevisionData = {
315-
[revision: string]: RevisionPathData
315+
[label: string]: RevisionPathData
316316
}
317317

318318
export type ComparisonData = {
319-
[revision: string]: {
319+
[label: string]: {
320320
[path: string]: ImagePlot
321321
}
322322
}
323323

324+
export type CLIRevisionIdToLabel = { [shortSha: string]: string }
325+
324326
const collectImageData = (
325327
acc: ComparisonData,
326328
path: string,
327-
plot: ImagePlot
329+
plot: ImagePlot,
330+
cliIdToLabel: CLIRevisionIdToLabel
328331
) => {
329332
const rev = plot.revisions?.[0]
330-
331333
if (!rev) {
332334
return
333335
}
334336

335-
if (!acc[rev]) {
336-
acc[rev] = {}
337+
const label = cliIdToLabel[rev]
338+
339+
if (!label) {
340+
return
341+
}
342+
343+
if (!acc[label]) {
344+
acc[label] = {}
337345
}
338346

339-
acc[rev][path] = plot
347+
acc[label][path] = plot
340348
}
341349

342350
const collectDatapoints = (
@@ -353,15 +361,17 @@ const collectDatapoints = (
353361
const collectPlotData = (
354362
acc: RevisionData,
355363
path: string,
356-
plot: TemplatePlot
364+
plot: TemplatePlot,
365+
cliIdToLabel: CLIRevisionIdToLabel
357366
) => {
358-
for (const rev of plot.revisions || []) {
359-
if (!acc[rev]) {
360-
acc[rev] = {}
367+
for (const id of plot.revisions || []) {
368+
const label = cliIdToLabel[id]
369+
if (!acc[label]) {
370+
acc[label] = {}
361371
}
362-
acc[rev][path] = []
372+
acc[label][path] = []
363373

364-
collectDatapoints(acc, path, rev, plot.datapoints?.[rev])
374+
collectDatapoints(acc, path, label, plot.datapoints?.[id])
365375
}
366376
}
367377

@@ -370,25 +380,33 @@ type DataAccumulator = {
370380
comparisonData: ComparisonData
371381
}
372382

373-
const collectPathData = (acc: DataAccumulator, path: string, plots: Plot[]) => {
383+
const collectPathData = (
384+
acc: DataAccumulator,
385+
path: string,
386+
plots: Plot[],
387+
cliIdToLabel: CLIRevisionIdToLabel
388+
) => {
374389
for (const plot of plots) {
375390
if (isImagePlot(plot)) {
376-
collectImageData(acc.comparisonData, path, plot)
391+
collectImageData(acc.comparisonData, path, plot, cliIdToLabel)
377392
continue
378393
}
379394

380-
collectPlotData(acc.revisionData, path, plot)
395+
collectPlotData(acc.revisionData, path, plot, cliIdToLabel)
381396
}
382397
}
383398

384-
export const collectData = (data: PlotsOutput): DataAccumulator => {
399+
export const collectData = (
400+
data: PlotsOutput,
401+
cliIdToLabel: CLIRevisionIdToLabel
402+
): DataAccumulator => {
385403
const acc = {
386404
comparisonData: {},
387405
revisionData: {}
388406
} as DataAccumulator
389407

390408
for (const [path, plots] of Object.entries(data)) {
391-
collectPathData(acc, path, plots)
409+
collectPathData(acc, path, plots, cliIdToLabel)
392410
}
393411

394412
return acc
@@ -543,7 +561,7 @@ export const collectBranchRevisionDetails = (
543561
const branchRevisions: Record<string, string> = {}
544562
for (const { id, sha } of branchShas) {
545563
if (sha) {
546-
branchRevisions[id] = sha
564+
branchRevisions[id] = shortenForLabel(sha)
547565
}
548566
}
549567
return branchRevisions

extension/src/plots/model/index.ts

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,15 @@ export class PlotsModel extends ModelWithPersistence {
9696
}
9797

9898
public async transformAndSetPlots(data: PlotsOutput, revs: string[]) {
99-
this.fetchedRevs = new Set([...this.fetchedRevs, ...revs])
99+
const cliIdToLabel = this.getCLIIdToLabel()
100+
101+
this.fetchedRevs = new Set([
102+
...this.fetchedRevs,
103+
...revs.map(rev => cliIdToLabel[rev])
104+
])
100105

101106
const [{ comparisonData, revisionData }, templates] = await Promise.all([
102-
collectData(data),
107+
collectData(data, cliIdToLabel),
103108
collectTemplates(data)
104109
])
105110

@@ -168,9 +173,9 @@ export class PlotsModel extends ModelWithPersistence {
168173
...Object.keys(this.revisionData)
169174
])
170175

171-
return this.getSelectedRevisions().filter(
172-
revision => !cachedRevisions.has(revision)
173-
)
176+
return this.getSelectedRevisions()
177+
.filter(label => !cachedRevisions.has(label))
178+
.map(label => this.getCLIId(label))
174179
}
175180

176181
public getMutableRevisions() {
@@ -186,16 +191,20 @@ export class PlotsModel extends ModelWithPersistence {
186191
this.comparisonOrder,
187192
this.experiments
188193
.getSelectedRevisions()
189-
.map(({ label: revision, displayColor, logicalGroupName, id }) => ({
194+
.map(({ label, displayColor, logicalGroupName, id }) => ({
190195
displayColor,
191196
group: logicalGroupName,
192197
id,
193-
revision
198+
revision: label
194199
})),
195200
'revision'
196201
)
197202
}
198203

204+
public getDefaultRevs() {
205+
return ['workspace', ...Object.values(this.branchRevisions)]
206+
}
207+
199208
public getTemplatePlots(order: TemplateOrder | undefined) {
200209
if (!definedAndNonEmpty(order)) {
201210
return
@@ -330,6 +339,20 @@ export class PlotsModel extends ModelWithPersistence {
330339
this.fetchedRevs.delete(id)
331340
}
332341

342+
private getCLIIdToLabel() {
343+
const mapping: { [shortSha: string]: string } = {}
344+
345+
for (const rev of this.getSelectedRevisions()) {
346+
mapping[this.getCLIId(rev)] = rev
347+
}
348+
349+
return mapping
350+
}
351+
352+
private getCLIId(label: string) {
353+
return this.branchRevisions[label] || label
354+
}
355+
333356
private getSelectedRevisions() {
334357
return this.experiments.getSelectedRevisions().map(({ label }) => label)
335358
}

extension/src/test/fixtures/plotsDiff/index.ts

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ import {
1414
} from '../../../plots/webview/contract'
1515
import { join } from '../../util/path'
1616
import { copyOriginalColors } from '../../../experiments/model/status/colors'
17+
import { getCLIBranchId, replaceBranchCLIId } from './util'
1718

1819
const basicVega = {
1920
[join('logs', 'loss.tsv')]: [
2021
{
2122
type: PlotsType.VEGA,
22-
revisions: ['workspace', 'main', '42b8736', '1ba7bcd', '4fb124a'],
23+
revisions: ['workspace', '53c3851', '42b8736', '1ba7bcd', '4fb124a'],
2324
datapoints: {
2425
workspace: [
2526
{
@@ -68,7 +69,7 @@ const basicVega = {
6869
timestamp: '1641966351758'
6970
}
7071
],
71-
main: [
72+
'53c3851': [
7273
{
7374
loss: '2.298783302307129',
7475
step: '0',
@@ -361,8 +362,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({
361362
},
362363
{
363364
type: PlotsType.IMAGE,
364-
revisions: ['main'],
365-
url: joinFunc(baseUrl, 'main_plots_acc.png')
365+
revisions: ['53c3851'],
366+
url: joinFunc(baseUrl, '53c3851_plots_acc.png')
366367
},
367368
{
368369
type: PlotsType.IMAGE,
@@ -388,8 +389,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({
388389
},
389390
{
390391
type: PlotsType.IMAGE,
391-
revisions: ['main'],
392-
url: joinFunc(baseUrl, 'main_plots_heatmap.png')
392+
revisions: ['53c3851'],
393+
url: joinFunc(baseUrl, '53c3851_plots_heatmap.png')
393394
},
394395
{
395396
type: PlotsType.IMAGE,
@@ -415,8 +416,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({
415416
},
416417
{
417418
type: PlotsType.IMAGE,
418-
revisions: ['main'],
419-
url: joinFunc(baseUrl, 'main_plots_loss.png')
419+
revisions: ['53c3851'],
420+
url: joinFunc(baseUrl, '53c3851_plots_loss.png')
420421
},
421422
{
422423
type: PlotsType.IMAGE,
@@ -468,10 +469,12 @@ const extendedSpecs = (plotsOutput: TemplatePlots): TemplatePlotSection[] => {
468469
data: {
469470
values:
470471
expectedRevisions.flatMap(revision =>
471-
originalPlot.datapoints?.[revision].map(values => ({
472-
...values,
473-
rev: revision
474-
}))
472+
originalPlot.datapoints?.[getCLIBranchId(revision)].map(
473+
values => ({
474+
...values,
475+
rev: revision
476+
})
477+
)
475478
) || []
476479
}
477480
} as TopLevelSpec,
@@ -557,7 +560,7 @@ export const getComparisonWebviewMessage = (
557560
for (const [path, plots] of Object.entries(getImageData(baseUrl, joinFunc))) {
558561
const revisionsAcc: ComparisonRevisionData = {}
559562
for (const { url, revisions } of plots) {
560-
const revision = revisions?.[0]
563+
const revision = replaceBranchCLIId(revisions?.[0])
561564
if (!revision) {
562565
continue
563566
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
export const replaceBranchCLIId = (revision: string): string => {
2+
if (revision === '53c3851') {
3+
return 'main'
4+
}
5+
return revision
6+
}
7+
8+
export const getCLIBranchId = (revision: string): string => {
9+
if (revision === 'main') {
10+
return '53c3851'
11+
}
12+
return revision
13+
}

0 commit comments

Comments
 (0)