Skip to content

Commit 10bdfe6

Browse files
committed
update SpikeLM resume train unittest
1 parent 71e3a03 commit 10bdfe6

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

darkit/lm/models/SpikeLM/trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from tqdm import tqdm
3+
from itertools import cycle
34
from torch.utils.data import DataLoader
45
from transformers import get_scheduler, PreTrainedTokenizer, PreTrainedTokenizerFast
56
from spikingjelly.activation_based import functional
@@ -123,7 +124,11 @@ def train(self, train_dataset, val_dataset=None):
123124
val_dataloader = self.fabric.setup_dataloaders(val_dataloader)
124125

125126
max_train_steps = self.config.max_train_steps
126-
pbar = tqdm(train_dataloader, total=max_train_steps, desc="Training")
127+
pbar = tqdm(
128+
cycle(train_dataloader),
129+
total=max_train_steps - self.current_step,
130+
desc="Training",
131+
)
127132
for batch in pbar:
128133
if self.current_step >= self.config.max_train_steps:
129134
break

test/test_model.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,5 +295,66 @@ def test_train_predict(self):
295295
print(f"Model {self.model_name} deleted.")
296296

297297

298+
# python -m unittest test.test_model.TestTrainerLMResume
299+
class TestTrainerLMResume(unittest.TestCase):
300+
model_name = f"SpikeLM-Test-{random.randint(1000, 9999)}"
301+
# model_name = "SpikeLM-Test-1417"
302+
303+
def test_train_predict(self):
304+
from darkit.lm.models.SpikeLM import TrainerConfig, SpikeLMConfig, SpikeLM
305+
306+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
307+
m_conf = SpikeLMConfig(
308+
vocab_size=tokenizer.vocab_size,
309+
hidden_size=72,
310+
num_hidden_layers=6,
311+
num_attention_heads=6,
312+
)
313+
model = SpikeLM(m_conf)
314+
315+
t_conf = TrainerConfig(
316+
name=self.model_name,
317+
batch_size=2,
318+
max_train_steps=10,
319+
)
320+
321+
# 训练模型
322+
wikitext = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
323+
wikitext_train = wikitext["train"] # type: ignore
324+
325+
with Trainer(model, tokenizer=tokenizer, config=t_conf) as trainer:
326+
trainer.train(wikitext_train)
327+
328+
self.assertTrue(
329+
trainer.is_name_exist,
330+
f"Model {self.model_name} not saved.",
331+
)
332+
model = SpikeLM(m_conf)
333+
t_conf2 = TrainerConfig(
334+
name=self.model_name,
335+
batch_size=2,
336+
max_train_steps=20,
337+
)
338+
339+
with Trainer(
340+
model, tokenizer=tokenizer, config=t_conf2, resume=self.model_name
341+
) as trainer:
342+
trainer.train(wikitext_train)
343+
344+
# 测试模型
345+
ctx_len = 64
346+
predicter = Predicter.from_pretrained(self.model_name)
347+
prompt = "hello world"
348+
print(prompt, end="")
349+
for char in predicter.predict(prompt, ctx_len=ctx_len):
350+
print(char, end="", flush=True)
351+
print()
352+
353+
# 删除模型
354+
if trainer.save_directory:
355+
shutil.rmtree(trainer.save_directory)
356+
print(f"Model {self.model_name} deleted.")
357+
358+
298359
if __name__ == "__main__":
299360
unittest.main()

0 commit comments

Comments
 (0)