Skip to content

Commit d7bdf6a

Browse files
authored
Account for dvc yaml potentially not having a train stage (#2571)
1 parent f107dfa commit d7bdf6a

File tree

3 files changed

+50
-7
lines changed

3 files changed

+50
-7
lines changed

extension/src/experiments/checkpoints/collect.test.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,34 @@ describe('collectHasCheckpoints', () => {
5151
expect(hasCheckpoints).toBe(false)
5252
})
5353

54+
it('should not fail if a train stage is not provided', () => {
55+
const hasCheckpoints = collectHasCheckpoints({
56+
stages: {
57+
extract: {
58+
cmd: 'tar -xzf data/images.tar.gz --directory data',
59+
deps: ['data/images.tar.gz'],
60+
outs: [{ 'data/images/': { cache: false } }]
61+
}
62+
}
63+
} as PartialDvcYaml)
64+
65+
expect(hasCheckpoints).toBe(false)
66+
})
67+
68+
it('should return true if any stage has checkpoints', () => {
69+
const hasCheckpoints = collectHasCheckpoints({
70+
stages: {
71+
extract: {
72+
cmd: 'tar -xzf data/images.tar.gz --directory data',
73+
deps: ['data/images.tar.gz'],
74+
outs: [{ 'data/images/': { cache: false, checkpoint: true } }]
75+
}
76+
}
77+
} as PartialDvcYaml)
78+
79+
expect(hasCheckpoints).toBe(true)
80+
})
81+
5482
it('should correctly classify a more complex dvc.yaml without checkpoint', () => {
5583
const hasCheckpoints = collectHasCheckpoints({
5684
stages: {
Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
1-
import { PartialDvcYaml } from '../../fileSystem'
1+
import { Out, PartialDvcYaml } from '../../fileSystem'
22

3-
export const collectHasCheckpoints = (yaml: PartialDvcYaml): boolean => {
4-
return !!yaml.stages.train.outs.some(out => {
3+
const stageHasCheckpoints = (outs: Out[] = []): boolean => {
4+
for (const out of outs) {
55
if (typeof out === 'string') {
6-
return false
6+
continue
77
}
8-
98
if (Object.values(out).some(file => file?.checkpoint)) {
109
return true
1110
}
12-
})
11+
}
12+
return false
13+
}
14+
15+
export const collectHasCheckpoints = (yaml: PartialDvcYaml): boolean => {
16+
for (const stage of Object.values(yaml?.stages || {})) {
17+
if (stageHasCheckpoints(stage?.outs)) {
18+
return true
19+
}
20+
}
21+
return false
1322
}

extension/src/fileSystem/index.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,15 @@ export const isSameOrChild = (root: string, path: string) => {
6767
return !rel.startsWith('..')
6868
}
6969

70+
export type Out =
71+
| string
72+
| Record<string, { checkpoint?: boolean; cache?: boolean }>
73+
7074
export type PartialDvcYaml = {
7175
stages: {
72-
train: { outs: (string | Record<string, { checkpoint?: boolean }>)[] }
76+
[stage: string]: {
77+
outs?: Out[]
78+
}
7379
}
7480
}
7581

0 commit comments

Comments
 (0)