|
1 | 1 | <script lang="ts">
|
2 |
| - import { onMount } from 'svelte'; |
3 | 2 | import { toast } from 'svelte-sonner';
|
4 | 3 | import { gotoWithi18n } from '$lib/i18n';
|
5 |
| - import Select from '$lib/components/select.svelte'; |
6 | 4 | import * as Alert from '$lib/components/ui/alert';
|
| 5 | + import Select from '$lib/components/select.svelte'; |
7 | 6 | import ConfigForm from '$lib/components/config-form.svelte';
|
8 |
| - import Button, { buttonVariants } from '$lib/components/ui/button/button.svelte'; |
| 7 | + import { Button, buttonVariants } from '$lib/components/ui/button'; |
9 | 8 | import Code from '$lib/components/code.svelte';
|
10 | 9 | import * as m from '$lib/paraglide/messages';
|
11 |
| - import { get, post } from '../api'; |
| 10 | + import { |
| 11 | + createForkNetwork as createForkNetworkApi, |
| 12 | + getTrainCommand, |
| 13 | + startTrainModel |
| 14 | + } from '$lib/apis/lm'; |
12 | 15 | import * as Dialog from '$lib/components/ui/dialog';
|
13 | 16 | import Input from '$lib/components/input.svelte';
|
14 | 17 |
|
15 |
| - let modelType = $state(''); |
16 |
| - let modelsOptions = $state<{ [key: string]: { model: any; trainer: any } } | null>(null); |
| 18 | + let { data } = $props(); |
17 | 19 |
|
18 |
| - let dataset = $state(''); |
19 |
| - let tokenizer = $state(''); |
20 |
| - let datasetList = $state<string[]>([]); |
21 |
| - let tokenizerList = $state<string[]>([]); |
| 20 | + let modelType = $state(Object.keys(data.modelsOptions)[0]); |
| 21 | +
|
| 22 | + let dataset = $state(data.resources[0][0]); |
| 23 | + let tokenizer = $state(data.resources[1][0]); |
| 24 | +
|
| 25 | + let fork = $state<string | null>(null); |
| 26 | + let resume = $state(''); |
22 | 27 |
|
23 |
| - let modelKeys = $derived( |
24 |
| - modelsOptions ? Object.keys(modelsOptions).map((item) => ({ value: item, label: item })) : [] |
25 |
| - ); |
26 | 28 | let modelOption = $state<any>({});
|
27 | 29 | let trainerOption = $state<any>({});
|
28 |
| - let modelConfig = $derived(modelsOptions ? modelsOptions[modelType]?.model : null); |
29 |
| - let trainerConfig = $derived(modelsOptions ? modelsOptions[modelType]?.trainer : null); |
30 |
| -
|
31 |
| - let forkName = $state<string | null>(null); |
| 30 | + let modelKeys = $derived( |
| 31 | + data.modelsOptions |
| 32 | + ? Object.keys(data.modelsOptions).map((item) => ({ value: item, label: item })) |
| 33 | + : [] |
| 34 | + ); |
| 35 | + let resumeList = $derived( |
| 36 | + data.trainedNetworks ? data.trainedNetworks.map((v) => ({ value: v.name, label: v.name })) : [] |
| 37 | + ); |
| 38 | + let datasetList = $derived(data.resources[0].map((item) => ({ value: item, label: item }))); |
| 39 | + let tokenizerList = $derived(data.resources[1].map((item) => ({ value: item, label: item }))); |
| 40 | + let modelConfig = $derived(data.modelsOptions ? data.modelsOptions[modelType]?.model : null); |
| 41 | + let trainerConfig = $derived(data.modelsOptions ? data.modelsOptions[modelType]?.trainer : null); |
32 | 42 |
|
33 |
| - let command = $derived.by(() => { |
34 |
| - const baseCommand = `darkit lm train --tokenizer ${tokenizer} --dataset ${dataset} ${modelType}`; |
35 |
| - return `${baseCommand} ${generateCommand(modelOption)} ${generateCommand(trainerOption)}`; |
| 43 | + let command = $derived.by(async () => { |
| 44 | + return getTrainCommand(modelType, fork, resume, dataset, tokenizer, modelOption, trainerOption); |
36 | 45 | });
|
37 | 46 |
|
38 |
| - async function startTrain(command: string) { |
| 47 | + async function startTrain() { |
39 | 48 | try {
|
40 |
| - const res = await post(`/model/${modelType}/train`, { |
41 |
| - command: command |
42 |
| - }); |
43 |
| - const data = await res.json(); |
44 |
| - if (res.status !== 200) { |
45 |
| - toast.error(data.detail); |
46 |
| - return; |
47 |
| - } |
| 49 | + const type = modelType; |
| 50 | + await startTrainModel(type, fork, resume, dataset, tokenizer, modelOption, trainerOption); |
48 | 51 | gotoWithi18n(`/lm/visual/${trainerOption.name}`);
|
49 | 52 | } catch (e) {
|
50 | 53 | console.error(e);
|
| 54 | + toast.error('Failed to start training'); |
51 | 55 | }
|
52 | 56 | }
|
53 | 57 |
|
54 | 58 | async function createForkNetwork() {
|
55 |
| - try { |
56 |
| - const res = await post(`/edit/init/${forkName}`, { |
57 |
| - model: modelType, |
58 |
| - m_conf: modelOption, |
59 |
| - t_conf: trainerOption |
60 |
| - }); |
61 |
| - const data = await res.json(); |
62 |
| - if (res.status !== 200) { |
63 |
| - toast.error(data.detail); |
64 |
| - return; |
| 59 | + if (fork) { |
| 60 | + try { |
| 61 | + await createForkNetworkApi(fork, modelType, modelOption, trainerOption); |
| 62 | + gotoWithi18n(`/lm/fork/${fork}`); |
| 63 | + } catch (e) { |
| 64 | + console.error(e); |
| 65 | + toast.error('Failed to create fork network'); |
65 | 66 | }
|
66 |
| -
|
67 |
| - gotoWithi18n(`/lm/fork/${forkName}`); |
68 |
| - } catch (e) { |
69 |
| - console.error(e); |
70 | 67 | }
|
71 | 68 | }
|
72 |
| -
|
73 |
| - // Generate command |
74 |
| - function generateCommand(config: any) { |
75 |
| - return Object.entries(config) |
76 |
| - .filter(([key, val]) => { |
77 |
| - return val !== null && val !== '' && val !== undefined; |
78 |
| - }) |
79 |
| - .map(([key, val]) => `--${key} ${val}`) |
80 |
| - .join(' '); |
81 |
| - } |
82 |
| -
|
83 |
| - onMount(async () => { |
84 |
| - const [modelsOptionsJson, resources] = await Promise.all([ |
85 |
| - get('/models/options').then((res) => res.json()), |
86 |
| - get('/train/resources').then((res) => res.json()) |
87 |
| - ]); |
88 |
| -
|
89 |
| - modelsOptions = modelsOptionsJson; |
90 |
| - if (modelsOptions) { |
91 |
| - if (modelType === '' || !Object.keys(modelsOptions).includes(modelType)) { |
92 |
| - modelType = Object.keys(modelsOptions)[0]; |
93 |
| - } |
94 |
| - } |
95 |
| - if (resources) { |
96 |
| - [datasetList, tokenizerList] = resources; |
97 |
| - if (dataset === '' || !datasetList.includes(dataset)) { |
98 |
| - dataset = datasetList ? datasetList[0] : ''; |
99 |
| - } |
100 |
| - if (tokenizer === '' || !tokenizerList.includes(tokenizer)) { |
101 |
| - tokenizer = tokenizerList ? tokenizerList[0] : ''; |
102 |
| - } |
103 |
| - } |
104 |
| - }); |
105 | 69 | </script>
|
106 | 70 |
|
107 | 71 | <div class="h-full flex-1 overflow-y-auto overflow-x-hidden p-8">
|
108 |
| - {#if modelsOptions} |
| 72 | + {#if data.modelsOptions} |
109 | 73 | <Alert.Root class="text-primary/60 [&:has(svg)]:pl-4">
|
110 | 74 | <Alert.Description>
|
111 | 75 | <span class="text-xl"> 🍾 </span>
|
|
120 | 84 | class="flex-1"
|
121 | 85 | label="Tokenizers"
|
122 | 86 | bind:value={tokenizer}
|
123 |
| - options={tokenizerList.map((v) => ({ value: v, label: v }))} |
| 87 | + options={tokenizerList} |
124 | 88 | tooltip={m.tokenizersTooltip()}
|
125 | 89 | />
|
| 90 | + <Select |
| 91 | + class="flex-1" |
| 92 | + label="Resume" |
| 93 | + bind:value={resume} |
| 94 | + options={resumeList} |
| 95 | + tooltip={m.resumeTooltip()} |
| 96 | + /> |
126 | 97 | </div>
|
127 | 98 | <ConfigForm key={modelType} config={modelConfig} bind:option={modelOption} />
|
128 | 99 | </section>
|
129 | 100 | <section class="py-4">
|
130 | 101 | <h3 class="mb-2 font-bold">Trainer Config</h3>
|
131 |
| - |
132 |
| - <Select |
133 |
| - label={m.dataset()} |
134 |
| - class="mb-4 w-96" |
135 |
| - bind:value={dataset} |
136 |
| - options={datasetList.map((v) => ({ value: v, label: v }))} |
137 |
| - /> |
138 |
| - |
| 102 | + <Select label={m.dataset()} class="mb-4 w-96" bind:value={dataset} options={datasetList} /> |
139 | 103 | <ConfigForm key={modelType} config={trainerConfig} bind:option={trainerOption} />
|
140 | 104 | </section>
|
141 | 105 |
|
142 | 106 | <h3 class="mb-2 mt-8 scroll-m-20 text-2xl font-semibold tracking-tight">
|
143 | 107 | {m.generatedCommand()}
|
144 | 108 | </h3>
|
145 | 109 |
|
146 |
| - <Code class="mt-8 max-h-96" lang="bash" content={command} wrap /> |
| 110 | + {#await command} |
| 111 | + <Code class="mt-8 max-h-96" lang="bash" content="Updating..." wrap /> |
| 112 | + {:then command} |
| 113 | + <Code class="mt-8 max-h-96" lang="bash" content={command} wrap /> |
| 114 | + {:catch error} |
| 115 | + <Code class="mt-8 max-h-96" lang="bash" content={error.detail} wrap /> |
| 116 | + {/await} |
147 | 117 |
|
148 | 118 | <div class="w-full text-right">
|
149 | 119 | <Dialog.Root>
|
|
163 | 133 | placeholder="Enter a name for the fork"
|
164 | 134 | class="w-full"
|
165 | 135 | variant="row"
|
166 |
| - bind:value={forkName} |
| 136 | + bind:value={fork} |
167 | 137 | />
|
168 | 138 | </div>
|
169 | 139 | <Dialog.Footer>
|
170 | 140 | <Button type="submit" onclick={createForkNetwork}>Create fork</Button>
|
171 | 141 | </Dialog.Footer>
|
172 | 142 | </Dialog.Content>
|
173 | 143 | </Dialog.Root>
|
174 |
| - <Button size="lg" class="mt-8" onclick={() => startTrain(command)}> |
| 144 | + <Button size="lg" class="mt-8" onclick={startTrain}> |
175 | 145 | {m.startTrain()}
|
176 | 146 | </Button>
|
177 | 147 | </div>
|
|
0 commit comments