Skip to content

Commit 71e3a03

Browse files
committed
CLI和web支持训练恢复功能
1 parent f22ed20 commit 71e3a03

File tree

15 files changed

+281
-224
lines changed

15 files changed

+281
-224
lines changed

darkit/core/trainer.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -332,23 +332,19 @@ def _auto_save_pretrained(self):
332332
"""
333333
根据 Train 的相关参数,控制训练时的自动保存逻辑。
334334
"""
335-
max_step, current_step, save_step_interval = (
336-
self.max_step,
337-
self.current_step,
338-
self.save_step_interval,
339-
)
340-
341-
current_step_idx = current_step + 1
335+
current_step_idx = self.current_step + 1
336+
# 如果设置 save_step_interval 为 0,则不保存 checkpoint
342337
if self.save_step_interval > 0:
338+
# 当当前步数(current_step_idx)为 max_step 或者是 save_step_interval 的倍数时保存模型
343339
if (
344-
current_step_idx == max_step
345-
or current_step_idx % save_step_interval == 0
340+
current_step_idx == self.max_step
341+
or current_step_idx % self.save_step_interval == 0
346342
):
347343
check_poinent = f"iter-{current_step_idx}-ckpt"
348344
self.save_pretrained(check_poinent=check_poinent)
349345
if self.save_directory:
350346
print(
351-
f"Model saved epoch {current_step_idx}/{max_step} at {check_poinent}"
347+
f"Model saved epoch {current_step_idx}/{self.max_step} at {check_poinent}"
352348
)
353349

354350
def _auto_validate(self, val_dataloader):
@@ -407,7 +403,6 @@ def _save_model(self, checkpoint="complete"):
407403
if self.save_directory:
408404
save_path = self.save_directory / f"{checkpoint}.pth"
409405
current_step_idx = self.current_step + 1
410-
print("current_step_idx", current_step_idx)
411406
save_dict = {
412407
"model_class": self.model.__class__.__name__,
413408
"state_dict": self.model,

darkit/core/utils/csv_logger.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import csv
22
from pathlib import Path
3-
from typing import Any, Collection, Union
3+
from typing import Any, Collection
44
from dataclasses import asdict
55

66

77
class CSVLogger:
8-
def __init__(self, filename: Union[str, Path], fieldnames: Collection[Any]):
8+
def __init__(self, filename: Path, fieldnames: Collection[Any]):
99
self.filename = filename
1010
self.fieldnames = fieldnames
11-
self.file = open(filename, "w")
11+
self.file = open(filename, "a")
1212
self.writer = csv.DictWriter(self.file, fieldnames=fieldnames)
13-
self.writer.writeheader()
13+
if self.file.tell() == 0:
14+
self.writer.writeheader()
1415

1516
def log(self, data):
1617
if hasattr(data, "__dataclass_fields__"):

darkit/core/web/messages/en.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"startPredict": "Start Predicting",
2727
"cxtLenTooltip": "Context length. The maximum sequence length that the model can consider when processing the input sequence.",
2828
"promptTooltip": "The prompt to use for generation.",
29+
"resumeTooltip": "Resume training from the model checkpoint, Ex: mymodel or mymodel:epoch-1.",
2930
"modelVisualHead": "Select the model for which you want to view details.",
3031
"showDetails": "Show Details",
3132
"mDeleteModel": "Delete Model",

darkit/core/web/messages/zh.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"startPredict": "开始预测",
2727
"cxtLenTooltip": "上下文长度。模型在处理输入序列时所能考虑的最大序列长度",
2828
"promptTooltip": "输入的提示词,模型将根据提示词生成后续文本",
29+
"resumeTooltip": "从指定模型检查点恢复训练,Ex: mymodel or mymodel:epoch-1",
2930
"modelVisualHead": "选择您想要查看其详细信息的模型。",
3031
"showDetails": "查看详情",
3132
"mDeleteModel": "删除模型",

darkit/core/web/src/lib/apis/lm.ts

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { get } from '$lib/api';
1+
import { get, post } from '$lib/api';
22

33
export type ModelLabel = {
44
default: string | number;
@@ -22,6 +22,7 @@ export interface TrainedModel {
2222
config: ExternalConfig;
2323
}
2424

25+
/** 获取以及训练完成的模型 */
2526
export const getTrainedNetworks = async () => {
2627
const res = await get('/lm/models');
2728
return res.json() as Promise<TrainedModel[]>;
@@ -36,3 +37,69 @@ export const getModelDetail = async (model: string) => {
3637
const res = await get(`/lm/models/${model}`);
3738
return res.json() as Promise<TrainedModel>;
3839
};
40+
41+
export const getModelsOptions = async () => {
42+
const res = await get('/lm/models/options');
43+
return res.json() as Promise<{ [key: string]: { model: any; trainer: any } }>;
44+
};
45+
46+
export const getTrainResources = async () => {
47+
const res = await get('/lm/train/resources');
48+
return res.json() as Promise<string[][]>;
49+
};
50+
51+
export const getTrainCommand = async (
52+
type: string,
53+
fork: string | null,
54+
resume: string | null,
55+
dataset: string,
56+
tokenizer: string,
57+
modelOption: any,
58+
trainerOption: any
59+
) => {
60+
const res = await post(`/lm/model/train/command`, {
61+
type: type,
62+
fork: fork,
63+
resume: resume,
64+
dataset: dataset,
65+
tokenizer: tokenizer,
66+
m_conf: modelOption,
67+
t_conf: trainerOption
68+
});
69+
return res.text() as Promise<string>;
70+
};
71+
72+
export const startTrainModel = async (
73+
type: string,
74+
fork: string | null,
75+
resume: string | null,
76+
dataset: string,
77+
tokenizer: string,
78+
modelOption: any,
79+
trainerOption: any
80+
) => {
81+
const res = await post('/lm/v2/model/train/', {
82+
type: type,
83+
fork: fork,
84+
resume: resume,
85+
dataset: dataset,
86+
tokenizer: tokenizer,
87+
m_conf: modelOption,
88+
t_conf: trainerOption
89+
});
90+
return res.json() as Promise<any>;
91+
};
92+
93+
export const createForkNetwork = async (
94+
name: string,
95+
modelType: string,
96+
modelOption: any,
97+
trainerOption: any
98+
) => {
99+
const res = await post(`/lm/edit/init/${name}`, {
100+
model: modelType,
101+
m_conf: modelOption,
102+
t_conf: trainerOption
103+
});
104+
return res.json() as Promise<any>;
105+
};

darkit/core/web/src/routes/lm/train/+page.svelte

Lines changed: 59 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,75 @@
11
<script lang="ts">
2-
import { onMount } from 'svelte';
32
import { toast } from 'svelte-sonner';
43
import { gotoWithi18n } from '$lib/i18n';
5-
import Select from '$lib/components/select.svelte';
64
import * as Alert from '$lib/components/ui/alert';
5+
import Select from '$lib/components/select.svelte';
76
import ConfigForm from '$lib/components/config-form.svelte';
8-
import Button, { buttonVariants } from '$lib/components/ui/button/button.svelte';
7+
import { Button, buttonVariants } from '$lib/components/ui/button';
98
import Code from '$lib/components/code.svelte';
109
import * as m from '$lib/paraglide/messages';
11-
import { get, post } from '../api';
10+
import {
11+
createForkNetwork as createForkNetworkApi,
12+
getTrainCommand,
13+
startTrainModel
14+
} from '$lib/apis/lm';
1215
import * as Dialog from '$lib/components/ui/dialog';
1316
import Input from '$lib/components/input.svelte';
1417
15-
let modelType = $state('');
16-
let modelsOptions = $state<{ [key: string]: { model: any; trainer: any } } | null>(null);
18+
let { data } = $props();
1719
18-
let dataset = $state('');
19-
let tokenizer = $state('');
20-
let datasetList = $state<string[]>([]);
21-
let tokenizerList = $state<string[]>([]);
20+
let modelType = $state(Object.keys(data.modelsOptions)[0]);
21+
22+
let dataset = $state(data.resources[0][0]);
23+
let tokenizer = $state(data.resources[1][0]);
24+
25+
let fork = $state<string | null>(null);
26+
let resume = $state('');
2227
23-
let modelKeys = $derived(
24-
modelsOptions ? Object.keys(modelsOptions).map((item) => ({ value: item, label: item })) : []
25-
);
2628
let modelOption = $state<any>({});
2729
let trainerOption = $state<any>({});
28-
let modelConfig = $derived(modelsOptions ? modelsOptions[modelType]?.model : null);
29-
let trainerConfig = $derived(modelsOptions ? modelsOptions[modelType]?.trainer : null);
30-
31-
let forkName = $state<string | null>(null);
30+
let modelKeys = $derived(
31+
data.modelsOptions
32+
? Object.keys(data.modelsOptions).map((item) => ({ value: item, label: item }))
33+
: []
34+
);
35+
let resumeList = $derived(
36+
data.trainedNetworks ? data.trainedNetworks.map((v) => ({ value: v.name, label: v.name })) : []
37+
);
38+
let datasetList = $derived(data.resources[0].map((item) => ({ value: item, label: item })));
39+
let tokenizerList = $derived(data.resources[1].map((item) => ({ value: item, label: item })));
40+
let modelConfig = $derived(data.modelsOptions ? data.modelsOptions[modelType]?.model : null);
41+
let trainerConfig = $derived(data.modelsOptions ? data.modelsOptions[modelType]?.trainer : null);
3242
33-
let command = $derived.by(() => {
34-
const baseCommand = `darkit lm train --tokenizer ${tokenizer} --dataset ${dataset} ${modelType}`;
35-
return `${baseCommand} ${generateCommand(modelOption)} ${generateCommand(trainerOption)}`;
43+
let command = $derived.by(async () => {
44+
return getTrainCommand(modelType, fork, resume, dataset, tokenizer, modelOption, trainerOption);
3645
});
3746
38-
async function startTrain(command: string) {
47+
async function startTrain() {
3948
try {
40-
const res = await post(`/model/${modelType}/train`, {
41-
command: command
42-
});
43-
const data = await res.json();
44-
if (res.status !== 200) {
45-
toast.error(data.detail);
46-
return;
47-
}
49+
const type = modelType;
50+
await startTrainModel(type, fork, resume, dataset, tokenizer, modelOption, trainerOption);
4851
gotoWithi18n(`/lm/visual/${trainerOption.name}`);
4952
} catch (e) {
5053
console.error(e);
54+
toast.error('Failed to start training');
5155
}
5256
}
5357
5458
async function createForkNetwork() {
55-
try {
56-
const res = await post(`/edit/init/${forkName}`, {
57-
model: modelType,
58-
m_conf: modelOption,
59-
t_conf: trainerOption
60-
});
61-
const data = await res.json();
62-
if (res.status !== 200) {
63-
toast.error(data.detail);
64-
return;
59+
if (fork) {
60+
try {
61+
await createForkNetworkApi(fork, modelType, modelOption, trainerOption);
62+
gotoWithi18n(`/lm/fork/${fork}`);
63+
} catch (e) {
64+
console.error(e);
65+
toast.error('Failed to create fork network');
6566
}
66-
67-
gotoWithi18n(`/lm/fork/${forkName}`);
68-
} catch (e) {
69-
console.error(e);
7067
}
7168
}
72-
73-
// Generate command
74-
function generateCommand(config: any) {
75-
return Object.entries(config)
76-
.filter(([key, val]) => {
77-
return val !== null && val !== '' && val !== undefined;
78-
})
79-
.map(([key, val]) => `--${key} ${val}`)
80-
.join(' ');
81-
}
82-
83-
onMount(async () => {
84-
const [modelsOptionsJson, resources] = await Promise.all([
85-
get('/models/options').then((res) => res.json()),
86-
get('/train/resources').then((res) => res.json())
87-
]);
88-
89-
modelsOptions = modelsOptionsJson;
90-
if (modelsOptions) {
91-
if (modelType === '' || !Object.keys(modelsOptions).includes(modelType)) {
92-
modelType = Object.keys(modelsOptions)[0];
93-
}
94-
}
95-
if (resources) {
96-
[datasetList, tokenizerList] = resources;
97-
if (dataset === '' || !datasetList.includes(dataset)) {
98-
dataset = datasetList ? datasetList[0] : '';
99-
}
100-
if (tokenizer === '' || !tokenizerList.includes(tokenizer)) {
101-
tokenizer = tokenizerList ? tokenizerList[0] : '';
102-
}
103-
}
104-
});
10569
</script>
10670

10771
<div class="h-full flex-1 overflow-y-auto overflow-x-hidden p-8">
108-
{#if modelsOptions}
72+
{#if data.modelsOptions}
10973
<Alert.Root class="text-primary/60 [&:has(svg)]:pl-4">
11074
<Alert.Description>
11175
<span class="text-xl"> 🍾 </span>
@@ -120,30 +84,36 @@
12084
class="flex-1"
12185
label="Tokenizers"
12286
bind:value={tokenizer}
123-
options={tokenizerList.map((v) => ({ value: v, label: v }))}
87+
options={tokenizerList}
12488
tooltip={m.tokenizersTooltip()}
12589
/>
90+
<Select
91+
class="flex-1"
92+
label="Resume"
93+
bind:value={resume}
94+
options={resumeList}
95+
tooltip={m.resumeTooltip()}
96+
/>
12697
</div>
12798
<ConfigForm key={modelType} config={modelConfig} bind:option={modelOption} />
12899
</section>
129100
<section class="py-4">
130101
<h3 class="mb-2 font-bold">Trainer Config</h3>
131-
132-
<Select
133-
label={m.dataset()}
134-
class="mb-4 w-96"
135-
bind:value={dataset}
136-
options={datasetList.map((v) => ({ value: v, label: v }))}
137-
/>
138-
102+
<Select label={m.dataset()} class="mb-4 w-96" bind:value={dataset} options={datasetList} />
139103
<ConfigForm key={modelType} config={trainerConfig} bind:option={trainerOption} />
140104
</section>
141105

142106
<h3 class="mb-2 mt-8 scroll-m-20 text-2xl font-semibold tracking-tight">
143107
{m.generatedCommand()}
144108
</h3>
145109

146-
<Code class="mt-8 max-h-96" lang="bash" content={command} wrap />
110+
{#await command}
111+
<Code class="mt-8 max-h-96" lang="bash" content="Updating..." wrap />
112+
{:then command}
113+
<Code class="mt-8 max-h-96" lang="bash" content={command} wrap />
114+
{:catch error}
115+
<Code class="mt-8 max-h-96" lang="bash" content={error.detail} wrap />
116+
{/await}
147117

148118
<div class="w-full text-right">
149119
<Dialog.Root>
@@ -163,15 +133,15 @@
163133
placeholder="Enter a name for the fork"
164134
class="w-full"
165135
variant="row"
166-
bind:value={forkName}
136+
bind:value={fork}
167137
/>
168138
</div>
169139
<Dialog.Footer>
170140
<Button type="submit" onclick={createForkNetwork}>Create fork</Button>
171141
</Dialog.Footer>
172142
</Dialog.Content>
173143
</Dialog.Root>
174-
<Button size="lg" class="mt-8" onclick={() => startTrain(command)}>
144+
<Button size="lg" class="mt-8" onclick={startTrain}>
175145
{m.startTrain()}
176146
</Button>
177147
</div>
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import { getModelsOptions, getTrainResources, getTrainedNetworks } from '$lib/apis/lm';
2+
3+
export const load = async () => {
4+
const resources = await getTrainResources();
5+
const modelsOptions = await getModelsOptions();
6+
const trainedNetworks = await getTrainedNetworks();
7+
8+
return {
9+
resources,
10+
modelsOptions,
11+
trainedNetworks
12+
};
13+
};

darkit/core/web/src/routes/lm/visual/[models]/+page.svelte

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import type { ModelDetail } from '$lib/apis/lm';
33
</script>
44

5+
<!-- #TODO: RESUME 适配 -->
56
<script lang="ts">
67
import { toast } from 'svelte-sonner';
78
import { onDestroy, onMount } from 'svelte';

0 commit comments

Comments
 (0)