Skip to content

Commit f22ed20

Browse files
committed
trainer 支持 resume
1 parent a649cd9 commit f22ed20

File tree

14 files changed

+451
-263
lines changed

14 files changed

+451
-263
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Project ignores
22
darkit/tmp
3+
test/test_datasets_temp
34

45

56
# Created by .ignore support plugin (hsz.mobi)

darkit/core/README.md

Lines changed: 0 additions & 12 deletions
This file was deleted.

darkit/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from . import utils
2-
from .trainer import Trainer, TrainerConfig, LogFieldnames
2+
from .trainer import Trainer, FabricTrainer, TrainerConfig, LogFieldnames
33
from .predicter import Predicter
44

55

66
__all__ = (
77
"utils",
88
"Trainer",
9+
"FabricTrainer",
910
"Predicter",
1011
"TrainerConfig",
1112
"LogFieldnames",

darkit/core/predicter.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import torch.nn as nn
44
from typing import Optional, Union
55
from pathlib import Path
6-
from .utils import MODEL_PATH
6+
from .lib.inject import inject_script
7+
from .utils import MODEL_PATH, model as model_utils
78

89

910
class Predicter:
@@ -47,54 +48,38 @@ def __new__(
4748
# 否则返回父类的实例
4849
return super().__new__(cls)
4950

51+
@classmethod
52+
def get_root(cls) -> Path:
53+
return MODEL_PATH
54+
5055
@classmethod
5156
def get_save_directory(cls, name: str) -> Path:
52-
save_directory = MODEL_PATH / name
57+
save_directory = cls.get_root() / name
5358
return save_directory
5459

60+
@classmethod
61+
def get_fork_directory(cls, fork: str) -> Optional[Path]:
62+
return model_utils.get_fork_directory(cls.get_root(), fork)
63+
5564
@classmethod
5665
def get_checkpoint(cls, name: str, checkpoint: Optional[str] = None) -> Path:
5766
save_directory = cls.get_save_directory(name)
58-
59-
# 寻找文件夹下的最新的 checkpoint 的 name
60-
if checkpoint:
61-
# check if the checkpoint exists
62-
if not (save_directory / f"{checkpoint}.pth").exists():
63-
raise FileNotFoundError(f"checkpoint {checkpoint} not found")
64-
return save_directory / f"{checkpoint}.pth"
65-
try:
66-
checkpoint_path = max(
67-
save_directory.glob("*.pth"), key=lambda x: x.stat().st_ctime
68-
)
69-
# 去掉后缀
70-
return checkpoint_path
71-
except ValueError:
72-
raise FileNotFoundError(f"checkpoint not found in {save_directory}")
67+
return model_utils.get_checkpoint(save_directory, checkpoint)
7368

7469
@classmethod
7570
def get_model_config_json(cls, name: str) -> dict:
7671
save_directory = cls.get_save_directory(name)
77-
with open(save_directory / "config.json", "r") as f:
78-
config_dict = json.load(f)
79-
return config_dict
72+
return model_utils.get_model_config_json(save_directory)
8073

8174
@classmethod
8275
def get_external_config_json(cls, name: str) -> Optional[dict]:
8376
save_directory = cls.get_save_directory(name)
84-
external_config_path = save_directory / "external_config.json"
85-
if external_config_path.exists():
86-
with open(external_config_path, "r") as f:
87-
config_dict = json.load(f)
88-
return config_dict
89-
else:
90-
return None
77+
return model_utils.get_external_config_json(save_directory)
9178

9279
@classmethod
9380
def get_trainer_config_json(cls, name: str) -> dict:
9481
save_directory = cls.get_save_directory(name)
95-
with open(save_directory / "trainer_config.json", "r") as f:
96-
config_dict = json.load(f)
97-
return config_dict
82+
return model_utils.get_trainer_config_json(save_directory)
9883

9984
@classmethod
10085
def get_model(cls, name: str, checkpoint: Optional[str] = None) -> nn.Module:
@@ -114,6 +99,16 @@ def get_model(cls, name: str, checkpoint: Optional[str] = None) -> nn.Module:
11499
print(f"Loading model {sub_class_name} from {checkpoint_path}")
115100
return sub_class.get_model(name, checkpoint)
116101

102+
@classmethod
103+
def inject_script(cls, model, name: str):
104+
external_config = cls.get_external_config_json(name)
105+
if external_config:
106+
fork = external_config.get("fork", None)
107+
fork_directory = cls.get_fork_directory(fork)
108+
if fork_directory is not None:
109+
model = inject_script(model, fork_directory)
110+
return model
111+
117112
@classmethod
118113
def from_pretrained(
119114
cls,
@@ -124,6 +119,7 @@ def from_pretrained(
124119
trainer_config = cls.get_trainer_config_json(name)
125120
device = device if device else trainer_config.get("device", "cuda")
126121
model = cls.get_model(name, checkpoint).to(device)
122+
model = cls.inject_script(model, name)
127123
return cls(name, model, device=device)
128124

129125
def _predict(self, ctx):

0 commit comments

Comments
 (0)