|
12 | 12 | import shutil
|
13 | 13 | import tempfile
|
14 | 14 | import unittest
|
15 |
| -from typing import Any, Dict, Iterator, List, Optional |
| 15 | +from typing import Any, Dict, Iterator, List, Optional, Tuple |
16 | 16 | from unittest import mock
|
17 | 17 | from unittest.mock import MagicMock, patch
|
18 | 18 |
|
19 | 19 | import torch
|
| 20 | +from pyre_extensions import none_throws |
20 | 21 | from torch import nn
|
21 | 22 | from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
|
| 23 | +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader |
22 | 24 | from torch.distributed.checkpoint.default_planner import (
|
23 | 25 | DefaultLoadPlanner,
|
24 | 26 | DefaultSavePlanner,
|
|
28 | 30 | from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
|
29 | 31 | from torchtnt.framework._test_utils import (
|
30 | 32 | DummyAutoUnit,
|
| 33 | + DummyEvalUnit, |
| 34 | + DummyMeanMetric, |
31 | 35 | DummyMultiOptimUnit,
|
| 36 | + DummyPredictUnit, |
32 | 37 | DummyTrainUnit,
|
| 38 | + generate_dummy_stateful_dataloader, |
33 | 39 | generate_random_dataloader,
|
34 | 40 | get_dummy_train_state,
|
35 | 41 | )
|
36 | 42 | from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
|
37 | 43 | from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
|
| 44 | +from torchtnt.framework.evaluate import evaluate |
| 45 | +from torchtnt.framework.fit import fit |
| 46 | +from torchtnt.framework.predict import predict |
38 | 47 |
|
39 | 48 | from torchtnt.framework.state import State
|
40 | 49 | from torchtnt.framework.train import train
|
| 50 | +from torchtnt.utils.checkpoint import get_latest_checkpoint_path |
41 | 51 | from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
|
42 | 52 | from torchtnt.utils.env import seed
|
43 | 53 | from torchtnt.utils.test_utils import skip_if_not_distributed
|
@@ -490,6 +500,234 @@ def test_save_restore_multi_optimizers(self) -> None:
|
490 | 500 | my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
|
491 | 501 | dcp_cb.restore_from_latest(temp_dir, my_unit_clone)
|
492 | 502 |
|
def test_save_predict(self) -> None:
    """
    Run ``predict`` with a checkpoint callback that saves every 2 predict
    steps, then verify the generated checkpoint directories, the path
    reported as latest, and the top-level entries found in the latest
    checkpoint's metadata.
    """
    input_dim, dataset_len, batch_size = 2, 10, 2

    my_unit = DummyPredictUnit(input_dim=input_dim)

    # pyre-ignore[16]: Add new attribute for testing
    my_unit.output_mean = DummyMeanMetric()

    # pyre-ignore[16]: Add at least one element to the metric
    my_unit.output_mean.update(1.0)

    predict_dl = generate_dummy_stateful_dataloader(
        dataset_len, input_dim, batch_size
    )

    expected_dirs = ["epoch_0_predict_step_2", "epoch_0_predict_step_4"]

    with tempfile.TemporaryDirectory() as ckpt_root:
        saver = DistributedCheckpointSaver(
            ckpt_root,
            knob_options=KnobOptions(1),
            save_every_n_predict_steps=2,
        )
        predict(my_unit, predict_dl, callbacks=[saver])

        self.assertCountEqual(os.listdir(ckpt_root), expected_dirs)

        # The latest checkpoint must be the last expected directory
        latest_path = none_throws(get_latest_checkpoint_path(ckpt_root))
        self.assertEqual(latest_path, os.path.join(ckpt_root, expected_dirs[-1]))

        ckpt_metadata = FsspecReader(latest_path).read_metadata()
        # Second dotted component of each key: the base key after the
        # app_state wrapper
        base_keys = {
            full_key.split(".")[1]
            for full_key in ckpt_metadata.state_dict_metadata
        }
        self.assertCountEqual(
            base_keys,
            ["predict_progress", "predict_dataloader", "output_mean"],
        )
| 551 | + |
def test_save_evaluate(self) -> None:
    """
    Run ``evaluate`` with a checkpoint callback that saves every 2 eval
    steps, then verify the generated checkpoint directories, the path
    reported as latest, and the top-level entries found in the latest
    checkpoint's metadata.
    """
    input_dim, dataset_len, batch_size = 2, 10, 2

    my_unit = DummyEvalUnit(input_dim=input_dim)
    eval_dl = generate_dummy_stateful_dataloader(
        dataset_len, input_dim, batch_size
    )

    expected_dirs = ["epoch_0_eval_step_2", "epoch_0_eval_step_4"]

    with tempfile.TemporaryDirectory() as ckpt_root:
        saver = DistributedCheckpointSaver(
            ckpt_root,
            knob_options=KnobOptions(1),
            save_every_n_eval_steps=2,
        )
        evaluate(my_unit, eval_dl, callbacks=[saver])

        self.assertCountEqual(os.listdir(ckpt_root), expected_dirs)

        # The latest checkpoint must be the last expected directory
        latest_path = none_throws(get_latest_checkpoint_path(ckpt_root))
        self.assertEqual(latest_path, os.path.join(ckpt_root, expected_dirs[-1]))

        ckpt_metadata = FsspecReader(latest_path).read_metadata()
        # Second dotted component of each key: the base key after the
        # app_state wrapper
        base_keys = {
            full_key.split(".")[1]
            for full_key in ckpt_metadata.state_dict_metadata
        }
        self.assertCountEqual(base_keys, ["eval_progress", "eval_dataloader"])
| 593 | + |
def test_save_fit_eval_every_n_epochs(self) -> None:
    """
    Run ``fit`` for one epoch with evaluation every epoch and checkpoints
    every 2 train steps and every 2 eval steps, then verify the generated
    checkpoint directories and which dataloader entry each checkpoint's
    metadata contains.
    """
    input_dim, dataset_len, batch_size = 2, 10, 2

    my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
    my_unit.output_mean = DummyMeanMetric()

    train_dl = generate_dummy_stateful_dataloader(
        dataset_len, input_dim, batch_size
    )
    eval_dl = generate_dummy_stateful_dataloader(
        dataset_len, input_dim, batch_size
    )

    # (checkpoint directory, dataloader key expected in its metadata)
    expected_ckpt_and_dl = [
        ("epoch_0_train_step_2_eval_step_0", "train_dataloader"),
        ("epoch_0_train_step_4_eval_step_0", "train_dataloader"),
        ("epoch_1_train_step_5_eval_step_2", "eval_dataloader"),
        ("epoch_1_train_step_5_eval_step_4", "eval_dataloader"),
    ]

    # Both train and eval checkpoints save full app_state in fit
    common_keys = [
        "module",
        "optimizer",
        "lr_scheduler",
        "train_progress",
        "eval_progress",
        "predict_progress",  # included because of AutoUnit
        "output_mean",
    ]

    with tempfile.TemporaryDirectory() as ckpt_root:
        saver = DistributedCheckpointSaver(
            ckpt_root,
            knob_options=KnobOptions(1),
            save_every_n_train_steps=2,
            save_every_n_eval_steps=2,
        )

        fit(
            my_unit,
            max_epochs=1,
            train_dataloader=train_dl,
            eval_dataloader=eval_dl,
            evaluate_every_n_epochs=1,
            callbacks=[saver],
        )

        self.assertCountEqual(
            os.listdir(ckpt_root),
            [ckpt_name for ckpt_name, _ in expected_ckpt_and_dl],
        )

        for ckpt_name, dl_key in expected_ckpt_and_dl:
            reader = FsspecReader(os.path.join(ckpt_root, ckpt_name))
            ckpt_metadata = reader.read_metadata()
            # Second dotted component of each key: the base key after the
            # app_state wrapper
            base_keys = {
                full_key.split(".")[1]
                for full_key in ckpt_metadata.state_dict_metadata
            }
            self.assertCountEqual(base_keys, [*common_keys, dl_key])
| 654 | + |
def test_save_fit_eval_every_n_steps(self) -> None:
    """
    Run ``fit`` for one epoch with evaluation every 2 train steps and
    checkpoints every 2 train steps and every 2 eval steps, then verify the
    generated checkpoint directories and which dataloader entries each
    checkpoint's metadata contains.
    """
    input_dim = 2

    my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
    my_unit.output_mean = DummyMeanMetric()

    train_dl = generate_dummy_stateful_dataloader(10, input_dim, 2)
    eval_dl = generate_dummy_stateful_dataloader(8, input_dim, 2)

    train_only: Tuple[str, ...] = ("train_dataloader",)
    train_and_eval: Tuple[str, ...] = ("train_dataloader", "eval_dataloader")

    expected_dls_by_ckpt: Dict[str, Tuple[str, ...]] = {
        # First train 2 steps
        "epoch_0_train_step_2_eval_step_0": train_only,
        # Then do a whole evaluation (4 steps)
        "epoch_0_train_step_2_eval_step_2": train_and_eval,
        "epoch_0_train_step_2_eval_step_4": train_and_eval,
        # Then train two more steps
        "epoch_0_train_step_4_eval_step_4": train_only,
        # Finally do a whole evaluation (4 steps)
        "epoch_0_train_step_4_eval_step_6": train_and_eval,
        "epoch_0_train_step_4_eval_step_8": train_and_eval,
        # Last checkpoint (on_train_end)
        "epoch_1_train_step_5_eval_step_8": (),
    }

    # Both train and eval checkpoints save full app_state in fit
    common_keys = [
        "module",
        "optimizer",
        "lr_scheduler",
        "train_progress",
        "eval_progress",
        "predict_progress",  # included because of AutoUnit
        "output_mean",
    ]

    with tempfile.TemporaryDirectory() as ckpt_root:
        saver = DistributedCheckpointSaver(
            ckpt_root,
            knob_options=KnobOptions(1),
            save_every_n_train_steps=2,
            save_every_n_eval_steps=2,
        )

        fit(
            my_unit,
            max_epochs=1,
            train_dataloader=train_dl,
            eval_dataloader=eval_dl,
            evaluate_every_n_steps=2,
            evaluate_every_n_epochs=None,
            callbacks=[saver],
        )

        self.assertCountEqual(os.listdir(ckpt_root), list(expected_dls_by_ckpt))

        for ckpt_name, expected_dls in expected_dls_by_ckpt.items():
            reader = FsspecReader(os.path.join(ckpt_root, ckpt_name))
            ckpt_metadata = reader.read_metadata()
            # Second dotted component of each key: the base key after the
            # app_state wrapper
            base_keys = {
                full_key.split(".")[1]
                for full_key in ckpt_metadata.state_dict_metadata
            }
            self.assertCountEqual(base_keys, [*common_keys, *expected_dls])
| 730 | + |
493 | 731 |
|
494 | 732 | class DummyStatefulDataLoader:
|
495 | 733 | def __init__(self, dataloader: DataLoader) -> None:
|
|
0 commit comments