@@ -295,5 +295,66 @@ def test_train_predict(self):
295
295
print (f"Model { self .model_name } deleted." )
296
296
297
297
298
# python -m unittest test.test_model.TestTrainerLMResume
class TestTrainerLMResume(unittest.TestCase):
    """Train a small SpikeLM, resume training under the same name, then predict.

    The test runs three stages against one saved model name:

    1. Train a freshly built SpikeLM with ``max_train_steps=10`` and check
       that the trainer reports the name as existing.
    2. Build a second SpikeLM and train it with ``max_train_steps=20`` through
       a ``Trainer`` created with ``resume=<model name>``.
    3. Load the saved model with ``Predicter.from_pretrained`` and stream a
       prediction for a short prompt to stdout.

    The saved model directory is removed at the end, including when one of
    the stages fails.
    """

    # Evaluated once, when the class body is executed.
    model_name = f"SpikeLM-Test-{random.randint(1000, 9999)}"
    # Fixed-name alternative (presumably for re-running against an existing save):
    # model_name = "SpikeLM-Test-1417"

    def test_train_predict(self):
        """Run train -> resume -> predict, then delete the saved model."""
        from darkit.lm.models.SpikeLM import TrainerConfig, SpikeLMConfig, SpikeLM

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        m_conf = SpikeLMConfig(
            vocab_size=tokenizer.vocab_size,
            hidden_size=72,
            num_hidden_layers=6,
            num_attention_heads=6,
        )
        model = SpikeLM(m_conf)

        t_conf = TrainerConfig(
            name=self.model_name,
            batch_size=2,
            max_train_steps=10,
        )

        wikitext = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
        wikitext_train = wikitext["train"]  # type: ignore

        # Most recently created trainer; the ``finally`` clause uses it to
        # locate the save directory no matter which stage raised.
        trainer = None
        try:
            # Train the model.
            with Trainer(model, tokenizer=tokenizer, config=t_conf) as trainer:
                trainer.train(wikitext_train)

            # NOTE(review): assumes ``is_name_exist`` is a property/attribute;
            # if it is a method, this assertion is always true - confirm.
            self.assertTrue(
                trainer.is_name_exist,
                f"Model {self.model_name} not saved.",
            )

            # Resume under the same name with a new model instance and a
            # larger step budget.
            model = SpikeLM(m_conf)
            t_conf2 = TrainerConfig(
                name=self.model_name,
                batch_size=2,
                max_train_steps=20,
            )

            with Trainer(
                model, tokenizer=tokenizer, config=t_conf2, resume=self.model_name
            ) as trainer:
                trainer.train(wikitext_train)

            # Test the model.
            ctx_len = 64
            predicter = Predicter.from_pretrained(self.model_name)
            prompt = "hello world"
            print(prompt, end="")
            for char in predicter.predict(prompt, ctx_len=ctx_len):
                print(char, end="", flush=True)
            print()
        finally:
            # Delete the model. Doing this in ``finally`` keeps a failed run
            # from leaving the saved model on disk.
            if trainer is not None and trainer.save_directory:
                # ignore_errors: a cleanup problem (e.g. the directory was
                # never created) must not mask the real test failure.
                shutil.rmtree(trainer.save_directory, ignore_errors=True)
                print(f"Model {self.model_name} deleted.")
357
+
358
+
298
359
# Script entry point: hand control to the unittest command-line runner.
if __name__ == "__main__":
    unittest.main()
0 commit comments