Skip to content

Commit a2f38a2

Browse files
authored
Merge pull request #1 from Yesifan/resume
Support resume train
2 parents 8fa64c7 + 58e30a7 commit a2f38a2

33 files changed

+830
-497
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Project ignores
22
darkit/tmp
3+
test/test_datasets_temp
34

45

56
# Created by .ignore support plugin (hsz.mobi)

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pip install -r requirements.txt
3434
```
3535
#### Notes
3636
- When running `npm run dev`, pay attention to the `server.proxy` configuration in `vite.config.ts`, it should match the address of the `FastAPI` service you started.
37-
- If you encounter errors like *Couldn't import the plugin "https://cdn.jsdelivr.net/npm/@inlang/message-lint-rule-without-source@latest/dist/index.js"*, it may be that your network cannot access `cdn.jsdelivr.net`. Please find a proxy or acceleration node and replace it in `project.inlang/settings.json`.
37+
- If you encounter errors like *Couldn't import the plugin "https://xxxxxxxxx.xxx/xxxxxxxxxxxx"*, it may be that your network cannot access `cdn.jsdelivr.net`. Please find a proxy or acceleration node and replace it in `project.inlang/settings.json`.
3838

3939
### Unit Testing
4040
#### Writing Unit Tests

README.zh.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ pip install -r requirements.txt
3232
```
3333
#### 注意事项
3434
- 运行 `npm run dev` 时需要注意 `vite.config.ts` 中的 `server.proxy` 配置, 应该确保与你启动的 `FastAPI` 服务地址一致。
35-
- 如果出现 *Couldn't import the plugin "https://cdn.jsdelivr.net/npm/@inlang/message-lint-rule-without-source@latest/dist/index.js"* 类似的错误,可能是你的网络无法访问 `cdn.jsdelivr.net`,请自行寻找代理或者加速节点并在 `project.inlang/settings.json` 中进行替换。
35+
- 如果出现 *Couldn't import the plugin "https://xxxxxxxxx.xxx/xxxxxxxxxxxx"* 类似的错误,可能是你的网络无法访问 `cdn.jsdelivr.net`,请自行寻找代理或者加速节点并在 `project.inlang/settings.json` 中进行替换。
3636

3737
### 单元测试
3838
#### 编写单元测试

darkit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.10"
1+
__version__ = "0.1.11"

darkit/core/README.md

Lines changed: 0 additions & 12 deletions
This file was deleted.

darkit/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from . import utils
2-
from .trainer import Trainer, TrainerConfig, LogFieldnames
2+
from .trainer import Trainer, FabricTrainer, TrainerConfig, LogFieldnames
33
from .predicter import Predicter
44

55

66
__all__ = (
77
"utils",
88
"Trainer",
9+
"FabricTrainer",
910
"Predicter",
1011
"TrainerConfig",
1112
"LogFieldnames",

darkit/core/predicter.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import torch.nn as nn
44
from typing import Optional, Union
55
from pathlib import Path
6-
from .utils import MODEL_PATH
6+
from .lib.inject import inject_script
7+
from .utils import MODEL_PATH, model as model_utils
78

89

910
class Predicter:
@@ -47,54 +48,38 @@ def __new__(
4748
# 否则返回父类的实例
4849
return super().__new__(cls)
4950

51+
@classmethod
52+
def get_root(cls) -> Path:
53+
return MODEL_PATH
54+
5055
@classmethod
5156
def get_save_directory(cls, name: str) -> Path:
52-
save_directory = MODEL_PATH / name
57+
save_directory = cls.get_root() / name
5358
return save_directory
5459

60+
@classmethod
61+
def get_fork_directory(cls, fork: str) -> Optional[Path]:
62+
return model_utils.get_fork_directory(cls.get_root(), fork)
63+
5564
@classmethod
5665
def get_checkpoint(cls, name: str, checkpoint: Optional[str] = None) -> Path:
5766
save_directory = cls.get_save_directory(name)
58-
59-
# 寻找文件夹下的最新的 checkpoint 的 name
60-
if checkpoint:
61-
# check if the checkpoint exists
62-
if not (save_directory / f"{checkpoint}.pth").exists():
63-
raise FileNotFoundError(f"checkpoint {checkpoint} not found")
64-
return save_directory / f"{checkpoint}.pth"
65-
try:
66-
checkpoint_path = max(
67-
save_directory.glob("*.pth"), key=lambda x: x.stat().st_ctime
68-
)
69-
# 去掉后缀
70-
return checkpoint_path
71-
except ValueError:
72-
raise FileNotFoundError(f"checkpoint not found in {save_directory}")
67+
return model_utils.get_checkpoint(save_directory, checkpoint)
7368

7469
@classmethod
7570
def get_model_config_json(cls, name: str) -> dict:
7671
save_directory = cls.get_save_directory(name)
77-
with open(save_directory / "config.json", "r") as f:
78-
config_dict = json.load(f)
79-
return config_dict
72+
return model_utils.get_model_config_json(save_directory)
8073

8174
@classmethod
8275
def get_external_config_json(cls, name: str) -> Optional[dict]:
8376
save_directory = cls.get_save_directory(name)
84-
external_config_path = save_directory / "external_config.json"
85-
if external_config_path.exists():
86-
with open(external_config_path, "r") as f:
87-
config_dict = json.load(f)
88-
return config_dict
89-
else:
90-
return None
77+
return model_utils.get_external_config_json(save_directory)
9178

9279
@classmethod
9380
def get_trainer_config_json(cls, name: str) -> dict:
9481
save_directory = cls.get_save_directory(name)
95-
with open(save_directory / "trainer_config.json", "r") as f:
96-
config_dict = json.load(f)
97-
return config_dict
82+
return model_utils.get_trainer_config_json(save_directory)
9883

9984
@classmethod
10085
def get_model(cls, name: str, checkpoint: Optional[str] = None) -> nn.Module:
@@ -114,6 +99,16 @@ def get_model(cls, name: str, checkpoint: Optional[str] = None) -> nn.Module:
11499
print(f"Loading model {sub_class_name} from {checkpoint_path}")
115100
return sub_class.get_model(name, checkpoint)
116101

102+
@classmethod
103+
def inject_script(cls, model, name: str):
104+
external_config = cls.get_external_config_json(name)
105+
if external_config:
106+
fork = external_config.get("fork", None)
107+
fork_directory = cls.get_fork_directory(fork)
108+
if fork_directory is not None:
109+
model = inject_script(model, fork_directory)
110+
return model
111+
117112
@classmethod
118113
def from_pretrained(
119114
cls,
@@ -124,6 +119,7 @@ def from_pretrained(
124119
trainer_config = cls.get_trainer_config_json(name)
125120
device = device if device else trainer_config.get("device", "cuda")
126121
model = cls.get_model(name, checkpoint).to(device)
122+
model = cls.inject_script(model, name)
127123
return cls(name, model, device=device)
128124

129125
def _predict(self, ctx):

0 commit comments

Comments
 (0)