@@ -7,14 +7,12 @@ def run_epoch(self):
77
88
99class ProbingCallback (dml .Callback ):
10-
1110 def __init__ (self , pipe = None , stage = None ):
1211 self .pipe = pipe
1312 self .stage = stage
1413 self .pipe_test = False
1514 self .stage_test = False
1615
17-
1816 def pre_run (self , pipe ):
1917 self .pipe_test = dml .current_pipe () is self .pipe
2018
@@ -23,9 +21,8 @@ def pre_stage(self, stage):
2321
2422
2523class LogCallback (dml .Callback ):
26-
2724 def __init__ (self ):
28- self . i = 0
25+ self .i = 0
2926
3027 def pre_epoch (self , stage ):
3128 dml .log_metric ('test' , self .i )
@@ -40,19 +37,17 @@ def test_accessors(self, torch_distributed):
4037 pipe .append (stage1 )
4138 pipe .append (stage2 )
4239
43-
4440 cb1 = ProbingCallback (pipe )
45- cb2 = ProbingCallback (stage = stage1 )
46- cb3 = ProbingCallback (stage = stage2 )
41+ cb2 = ProbingCallback (stage = stage1 )
42+ cb3 = ProbingCallback (stage = stage2 )
4743
4844 pipe .add_callback (cb1 )
4945 stage1 .add_callback (cb2 )
5046 stage2 .add_callback (cb3 )
5147
52-
5348 assert dml .current_pipe () is None
5449 assert dml .current_stage () is None
55-
50+
5651 pipe .run ()
5752 assert cb1 .pipe_test
5853 assert cb2 .stage_test
@@ -61,7 +56,6 @@ def test_accessors(self, torch_distributed):
6156 assert dml .current_pipe () is None
6257 assert dml .current_stage () is None
6358
64-
6559 def test_logging (self , torch_distributed ):
6660 pipe = dml .Pipeline ()
6761 stage1 = DummyStage (epochs = 3 )
@@ -74,13 +68,7 @@ def test_logging(self, torch_distributed):
7468 pipe .run ()
7569
7670 assert 'test' in stage1 .history
77- assert list (stage1 .history ['test' ]) == [0 ,1 , 2 ]
78-
71+ assert list (stage1 .history ['test' ]) == [0 , 1 , 2 ]
72+
7973 assert 'test' in stage2 .history
8074 assert list (stage2 .history ['test' ]) == [3 ]
81-
82-
83-
84-
85- if __name__ == '__main__' :
86- sys .exit (pytest .main ([__file__ ]))
0 commit comments