Skip to content

Commit 1d26dfe

Browse files
galrotemfacebook-github-bot
authored andcommitted
add progress to error messages (#779)
Summary: Pull Request resolved: #779 Reviewed By: JKSenthil Differential Revision: D55891686 fbshipit-source-id: 15741a2d7cea147169dc0a9b4ceb4cafe44766db
1 parent 23191d4 commit 1d26dfe

File tree

9 files changed

+88
-5
lines changed

9 files changed

+88
-5
lines changed

tests/framework/test_evaluate.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,16 @@ def test_evaluate_timing(self) -> None:
205205
)
206206
self.assertIn("evaluate.next(data_iter)", timer.recorded_durations.keys())
207207

208+
def test_error_message(self) -> None:
209+
with self.assertRaises(ValueError), self.assertLogs(level="INFO") as log:
210+
evaluate(EvalUnitWithError(), [1, 2, 3, 4])
211+
212+
self.assertIn(
213+
"INFO:torchtnt.framework.evaluate:Exception during evaluate after the following progress: "
214+
"completed epochs: 0, completed steps: 1, completed steps in current epoch: 1.:\nfoo",
215+
log.output,
216+
)
217+
208218

209219
class StopEvalUnit(EvalUnit[Tuple[torch.Tensor, torch.Tensor]]):
210220
def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
@@ -235,4 +245,7 @@ def eval_step(
235245
return loss, outputs
236246

237247

238-
Batch = Tuple[torch.Tensor, torch.Tensor]
248+
class EvalUnitWithError(EvalUnit[int]):
249+
def eval_step(self, state: State, data: int) -> None:
250+
if self.eval_progress.num_steps_completed == 1:
251+
raise ValueError("foo")

tests/framework/test_fit.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,30 @@ def test_fit_timing(self) -> None:
328328
)
329329
self.assertIn("train.next(data_iter)", timer.recorded_durations.keys())
330330
self.assertIn("evaluate.next(data_iter)", timer.recorded_durations.keys())
331+
332+
def test_error_message(self) -> None:
333+
self.maxDiff = None
334+
with self.assertRaises(ValueError), self.assertLogs(level="INFO") as log:
335+
fit(
336+
UnitWithError(),
337+
train_dataloader=[1, 2, 3, 4],
338+
eval_dataloader=[1, 2, 3, 4],
339+
max_steps=10,
340+
evaluate_every_n_epochs=1,
341+
)
342+
343+
self.assertIn(
344+
"INFO:torchtnt.framework.fit:Exception during fit after the following progress: train "
345+
"progress: completed epochs: 1, completed steps: 4, completed steps in current epoch: 0. "
346+
"eval progress: completed epochs: 0, completed steps: 2, completed steps in current epoch: 2.:\nfoo",
347+
log.output,
348+
)
349+
350+
351+
class UnitWithError(TrainUnit[int], EvalUnit[int]):
352+
def train_step(self, state: State, data: int) -> None:
353+
pass
354+
355+
def eval_step(self, state: State, data: int) -> None:
356+
if self.eval_progress.num_steps_completed == 2:
357+
raise ValueError("foo")

tests/framework/test_predict.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,16 @@ def test_predict_timing(self) -> None:
213213
)
214214
self.assertIn("predict.next(data_iter)", timer.recorded_durations.keys())
215215

216+
def test_error_message(self) -> None:
217+
with self.assertRaises(ValueError), self.assertLogs(level="INFO") as log:
218+
predict(PredictUnitWithError(), [1, 2, 3, 4])
219+
220+
self.assertIn(
221+
"INFO:torchtnt.framework.predict:Exception during predict after the following progress: "
222+
"completed epochs: 0, completed steps: 3, completed steps in current epoch: 3.:\nfoo",
223+
log.output,
224+
)
225+
216226

217227
Batch = Tuple[torch.Tensor, torch.Tensor]
218228

@@ -238,3 +248,9 @@ def predict_step(self, state: State, data: Batch) -> torch.Tensor:
238248

239249
self.steps_processed += 1
240250
return outputs
251+
252+
253+
class PredictUnitWithError(PredictUnit[int]):
254+
def predict_step(self, state: State, data: int) -> None:
255+
if self.predict_progress.num_steps_completed == 3:
256+
raise ValueError("foo")

tests/framework/test_train.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,16 @@ def test_train_timing(self) -> None:
263263
)
264264
self.assertIn("train.next(data_iter)", timer.recorded_durations.keys())
265265

266+
def test_error_message(self) -> None:
267+
with self.assertRaises(ValueError), self.assertLogs(level="INFO") as log:
268+
train(TrainUnitWithError(), [1, 2, 3, 4], max_steps=10)
269+
270+
self.assertIn(
271+
"INFO:torchtnt.framework.train:Exception during train after the following progress: "
272+
"completed epochs: 0, completed steps: 2, completed steps in current epoch: 2.:\nfoo",
273+
log.output,
274+
)
275+
266276

267277
Batch = Tuple[torch.Tensor, torch.Tensor]
268278

@@ -300,4 +310,10 @@ def train_step(
300310
return loss, outputs
301311

302312

313+
class TrainUnitWithError(TrainUnit[Batch]):
314+
def train_step(self, state: State, data: Batch) -> None:
315+
if self.train_progress.num_steps_completed == 2:
316+
raise ValueError("foo")
317+
318+
303319
Batch = Tuple[torch.Tensor, torch.Tensor]

torchtnt/framework/evaluate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def evaluate(
9595
logger.info(get_timer_summary(state.timer))
9696
except Exception as e:
9797
# TODO: log for diagnostics
98-
logger.info(e)
98+
logger.info(
99+
f"Exception during evaluate after the following progress: {eval_unit.eval_progress.get_progress_string()}:\n{e}"
100+
)
99101
eval_unit.on_exception(state, e)
100102
callback_handler.on_exception(state, eval_unit, e)
101103
raise e

torchtnt/framework/fit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def fit(
136136
logger.info(get_timer_summary(state.timer))
137137
except Exception as e:
138138
# TODO: log for diagnostics
139-
logger.info(f"Exception during fit:\n{e}")
139+
logger.info(
140+
f"Exception during fit after the following progress: train progress: {unit.train_progress.get_progress_string()} eval progress: {unit.eval_progress.get_progress_string()}:\n{e}"
141+
)
140142
unit.on_exception(state, e)
141143
callback_handler.on_exception(state, unit, e)
142144
raise e

torchtnt/framework/predict.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def predict(
9595
logger.info(get_timer_summary(state.timer))
9696
except Exception as e:
9797
# TODO: log for diagnostics
98-
logger.info(f"Exception during predict:\n{e}")
98+
logger.info(
99+
f"Exception during predict after the following progress: {predict_unit.predict_progress.get_progress_string()}:\n{e}"
100+
)
99101
predict_unit.on_exception(state, e)
100102
callback_handler.on_exception(state, predict_unit, e)
101103
raise e

torchtnt/framework/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ def train(
105105
logger.info(get_timer_summary(state.timer))
106106
except Exception as e:
107107
# TODO: log for diagnostics
108-
logger.info(f"Exception during train:\n{e}")
108+
logger.info(
109+
f"Exception during train after the following progress: {train_unit.train_progress.get_progress_string()}:\n{e}"
110+
)
109111
train_unit.on_exception(state, e)
110112
callback_handler.on_exception(state, train_unit, e)
111113
raise e

torchtnt/utils/progress.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
6262
self._num_steps_completed = state_dict["num_steps_completed"]
6363
self._num_steps_completed_in_epoch = state_dict["num_steps_completed_in_epoch"]
6464

65+
def get_progress_string(self) -> str:
66+
return f"completed epochs: {self.num_epochs_completed}, completed steps: {self.num_steps_completed}, completed steps in current epoch: {self.num_steps_completed_in_epoch}."
67+
6568

6669
def estimated_steps_in_epoch(
6770
dataloader: Iterable[object],

0 commit comments

Comments
 (0)