
Commit 2c8a68a

Documentation content update

1 parent 6a1d312 commit 2c8a68a

16 files changed: +379 −284 lines

darkit/core/predicter.py

Lines changed: 5 additions & 2 deletions
@@ -116,10 +116,13 @@ def from_pretrained(
         device: Optional[str] = None,
         checkpoint: Optional[str] = None,
     ):
+        """
+        Cannot be instantiated directly from the parent Predicter class; the root path must be consistent, otherwise the model cannot be located by its name.
+        """
+        model = cls.get_model(name, checkpoint)
+        model = cls.inject_script(model, name)
         trainer_config = cls.get_trainer_config_json(name)
         device = device if device else trainer_config.get("device", "cuda")
-        model = cls.get_model(name, checkpoint).to(device)
-        model = cls.inject_script(model, name)
         return cls(name, model, device=device)

     def _predict(self, ctx):
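
For reference, the user-guide diff later in this commit loads trained models through the `darkit.lm` Predicter rather than instantiating this base class directly. A minimal usage sketch, assuming a model was already trained and saved under the example name `GPT-1` (the name and `ctx_len=64` are example values taken from this commit's docs):

```python
# Minimal sketch, not part of the commit: load a previously trained model by
# name via the concrete Predicter exposed by darkit.lm, as the docs below do.
from darkit.lm.main import Predicter

predicter = Predicter.from_pretrained("GPT-1", device="cuda")
for token in predicter.predict("I am", ctx_len=64):
    print(token, end="", flush=True)
print()
```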

darkit/core/trainer.py

Lines changed: 6 additions & 6 deletions
@@ -197,7 +197,7 @@ def __new__(

     @property
     def root(self) -> Path:
-        return MODEL_PATH
+        return MODEL_PATH / "base"

     @property
     def save_directory(self) -> Optional[Path]:

@@ -305,13 +305,13 @@ def _save_external_config(self):
     def _copy_model_code(self):
         try:
             if self.save_directory:
-                model_py_path = inspect.getfile(self.model.__class__)
-                with open(model_py_path, "r", encoding="utf-8") as f:
-                    model_source_code = f.read()
-                with open(self.model_code_archive_path, "w", encoding="utf-8") as f:
-                    f.write(model_source_code)
+                model_source_code = inspect.getsource(self.model.__class__)
+                with open(self.model_code_archive_path, "w", encoding="utf-8") as f:
+                    f.write(model_source_code)
         except OSError as e:
             print("Save model code failed:", e)
+        except TypeError as e:
+            print("Cannot retrieve source code for built-in class:", e)

    def save_pretrained(self, check_poinent="complete"):
        """

darkit/lm/command.py

Lines changed: 10 additions & 5 deletions
@@ -26,14 +26,19 @@ def show():

     click.echo("TRAINED MODELS:")
     # Read the model folders under MODEL_PATH and print the model names
-    for i, model in enumerate(MODEL_PATH.iterdir()):
-        if model.is_dir():  # exclude the __options__.json file
+    for model in MODEL_PATH.iterdir():
+        if model.is_dir() and any(
+            version.suffix == ".pth" for version in model.iterdir()
+        ):  # only show folders that contain .pth files
             click.echo(f" - {model.name}")
             # Each .pth file in the model folder is one version of the trained model weights
             # Print each version in the form model:version
-            for j, version in enumerate(model.iterdir()):
+            i = 1
+            # Sort by file modification time
+            for version in sorted(model.iterdir(), key=lambda x: x.stat().st_mtime):
                 if version.suffix == ".pth":
-                    click.echo(f" {j + 1}. {version.stem}")
+                    click.echo(f" {i}. {version.stem}")
+                    i += 1
     click.echo()


@@ -62,7 +67,7 @@ def predict(model_type: str, model_name: str, prompt: str, device: str, ctx_len:
    Examples: darkit predict SpikeGPT SpikeGPT:complete "I am" --tokenizer gpt2 --ctx_len 512
    """
    import torch
-    from darkit.core import Predicter
+    from .main import Predicter

    # model_name = MODEL_NAME:MODEL_VERSION
    # Split model_name into model_name and version; default to 'complete' if no version is given
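
The reworked `show()` loop is essentially a standard `pathlib` pattern: keep only `.pth` checkpoints and order them by modification time. A standalone sketch of that pattern (the directory path is hypothetical; DarwinKit's model directory defaults to `~/.cache/DarwinKit` per the user guide, but the exact folder layout used here is an assumption):

```python
from pathlib import Path

# Hypothetical checkpoint folder for a model named "GPT-1".
model_dir = Path.home() / ".cache" / "DarwinKit" / "GPT-1"

# Keep only .pth checkpoints and sort oldest-to-newest by modification time,
# mirroring the ordering now used by `darkit lm show`.
checkpoints = sorted(
    (p for p in model_dir.iterdir() if p.suffix == ".pth"),
    key=lambda p: p.stat().st_mtime,
)
for i, ckpt in enumerate(checkpoints, start=1):
    print(f"{i}. {ckpt.stem}")
```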

docs/2.User-guide/2.How-use-web.md

Lines changed: 15 additions & 2 deletions
@@ -22,6 +22,9 @@ After setting the parameters of the model, click the `Train` button at the bottom

 After the training is completed, the model will be saved (if the user chooses to save the model in the settings) to the DarwinKit model directory (default is `~/.cache/DarwinKit`, which can be modified by changing the `DSPIKE_HOME` environment variable). Then users can use the model for prediction on the prediction page or view the training logs and parameters of the model on the visualization page.

+#### Resume Training
+If users need to resume training from the weights of a previously trained model, they only need to select that model in the `Resume` dropdown box on the training page and then click the `Train` button to continue training from the last checkpoint.
+
 ### Predict Models
 Users can use trained models for prediction. On the prediction page, users can select a trained model, then input text, and click the predict button to get the prediction result of the model.

@@ -30,11 +33,21 @@ All trained models will be displayed in the `Model Name` dropdown box, and users

 After starting the prediction, the output of the model will be displayed on the page in real-time.

+### Model Forking
+The model forking feature gives developers a new way to customize models. The `Fork` operation creates a forked copy of an existing model that can be edited, managed, and trained on its own; all subsequent operations on the fork are independent of the original model, protecting the integrity and security of the original.
+
+**User Guide**
+1. On the train models page, select a model and set the relevant parameters. Click the `Fork` button, enter the name of the forked model in the pop-up box, and click `Create fork` to create the forked model.
+![Fork Step 1](/static/docs/fork/step1.png)
+2. After the forked model is created, you are redirected to the model editing page, which is introduced below.
+![Fork Step 2](/static/docs/fork/step2.png)
+3. After editing, click `Fork` in the sidebar to open the management page for forked models. Select a model in the `Forked Model` dropdown box; you can then click the `View & Edit` button to edit it or the `Train` button to train it.
+![Fork Step 3](/static/docs/fork/step3.png)
+
 ### Model Visualization
 Users can view the training logs and parameters of trained models. On the visualization page, users can select trained models (multiple models can be selected for data comparison), and click the view button to see the visualized charts of the parameters.

 The schematic diagram is as follows:
 ![model visual](/static/docs/visual.jpg)

-If the model is in training, the page will update the data in real-time.
-
+If the model is in training, the page will update the data in real-time.

docs/2.User-guide/3.How-use-cli.md

Lines changed: 11 additions & 1 deletion
@@ -14,9 +14,19 @@ Options:

 Commands:
   create-options  Generate the configuration file for the model.
-  predict         Use the trained SNN model for inference. Optional model types can be viewed using the command DarwinKit show...
+  predict         Use the trained SNN model for inference. Optional model types can be viewed using the command darkit show...
   show            Display the available model_types, datasets, or...
   start           Start the WEB service.
   train           Train the SNN model.
 ```

+## Example
+```bash
+# Train the model
+darkit lm train --tokenizer openai-community/gpt2 --dataset Salesforce/wikitext:wikitext-103-raw-v1 SpikeGPT --vocab_size 30500 --ctx_len 1024
+# Use the model for prediction
+darkit lm predict SpikeGPT $model_name $prompt --device cuda
+# View trained models
+darkit lm show
+```
+

docs/2.User-guide/4.How-use-model.md

Lines changed: 57 additions & 12 deletions
@@ -33,10 +33,10 @@ n_embd = 768

 config = SpikeGPTConfig(
     tokenizer.vocab_size,
-    train_dataset.ctx_len,
-    model_type=model_type,
-    n_layer=n_layer,
-    n_embd=n_embd,
+    ctx_len=ctx_len,
+    model_type="RWKV",
+    n_layer=12,
+    n_embd=768,
 )
 model = SpikeGPT(config).cuda()
 ```

@@ -50,19 +50,19 @@ from darkit import Trainer
 from darkit.models import TrainerConfig

 # Parameter configuration
-model_name = f"GPT-Test-Train-{random.randint(1000, 9999)}"
+model_name = "GPT-1"
 tconf = TrainerConfig(
     name=model_name,
-    device="cuda",
+    device=device,
     max_epochs=1,
     epoch_length_fixed=100,
     batch_size=2,
-    epoch_save_frequency=1,
+    save_step_interval=1,
 )
 # Configure the model, dataset, and tokenizer
-trainer = Trainer(model, tokenizer=tokenizer, config=tconf)
-# Start training
-trainer.train(train_dataset=wikitext_train)
+with Trainer(model, tokenizer=tokenizer, config=tconf) as trainer:
+    # Start training
+    trainer.train(train_dataset=wikitext_train)
 ```
 The `TrainerConfig` class is used to configure the training parameters. Specific parameters can be referenced in the definition of the `TrainerConfig` class.

@@ -71,13 +71,13 @@ The `TrainerConfig` class is used to configure the training parameters. Specific
 ### Saving and Loading the Model
 During model training, the logic for saving the model is generally controlled according to the settings in `TrainerConfig`. For example, in the `TrainerConfig` of `SpikeGPT`, we can set `save_step_interval` to control the interval for saving the model.

-The path for saving the model is determined based on the values of `tconf.name` and the `DSPIKE_HOME` environment variable.
+The path for saving the model is determined based on the values of `tconf.name` and the `DARWIN_KIT_HOME` environment variable.

 ### Generating Text
 After training is complete, the trained model can be loaded using the model name set during training. We can use the following code to generate text:

 ```python
-from darkit import Predicter
+from darkit.lm.main import Predicter
 predicter = Predicter.from_pretrained(model_name)

 prompt = "I am"

@@ -93,3 +93,48 @@ We can use the `predict` method to generate text. The `predict` method accepts a
 The schematic diagram is as follows:

 ![SpikeGPT Run](/static/docs/SpikeGPTRun.gif)
+
+## Complete Code
+```python
+from datasets import load_dataset
+from transformers import AutoTokenizer, GPT2Tokenizer
+from darkit.lm.main import Trainer, Predicter
+from darkit.lm.models.SpikeGPT import SpikeGPT, SpikeGPTConfig, TrainerConfig
+
+device = "cuda"
+ctx_len = 64
+
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+tokenizer.pad_token = tokenizer.eos_token
+
+wikitext = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
+wikitext_train = wikitext["train"]  # type: ignore
+
+model_name = "GPT-1"
+config = SpikeGPTConfig(
+    tokenizer.vocab_size,
+    ctx_len=ctx_len,
+    model_type="RWKV",
+    n_layer=12,
+    n_embd=768,
+)
+model = SpikeGPT(config)
+tconf = TrainerConfig(
+    name=model_name,
+    device=device,
+    max_epochs=1,
+    epoch_length_fixed=100,
+    batch_size=2,
+    save_step_interval=1,
+)
+with Trainer(model, tokenizer=tokenizer, config=tconf) as trainer:
+    trainer.train(train_dataset=wikitext_train)
+
+# Test the model
+predicter = Predicter.from_pretrained(model_name)
+prompt = "hello world"
+print(prompt, end="")
+for char in predicter.predict(prompt, ctx_len=ctx_len):
+    print(char, end="", flush=True)
+print()
+```
