diff --git a/conf/dataset/default.yaml b/conf/dataset/default.yaml
index c640689..0a561ce 100644
--- a/conf/dataset/default.yaml
+++ b/conf/dataset/default.yaml
@@ -1,17 +1,15 @@
 # @package dataset
-# cfg:
-  # torch data-loader specific arguments
-cfg:
+cfg:
   batch_size: ${training.batch_size}
   num_workers: ${training.num_workers}
   dataroot: data
 
-  common_transform:
-  aug_transform:
-  pre_transform:
+  # common_transform:
+  # aug_transform:
+  # pre_transform:
 
-  val_transform: "${dataset.cfg.common_transform}"
-  test_transform: "${dataset.cfg.val_transform}"
-  train_transform:
-    - "${dataset.cfg.aug_transform}"
-    - "${dataset.cfg.common_transform}"
\ No newline at end of file
+  # val_transform: "${dataset.cfg.common_transform}"
+  # test_transform: "${dataset.cfg.val_transform}"
+  # train_transform:
+  #   - "${dataset.cfg.aug_transform}"
+  #   - "${dataset.cfg.common_transform}"
\ No newline at end of file
diff --git a/conf/dataset/segmentation/s3dis/s3dis1x1.yaml b/conf/dataset/segmentation/s3dis/s3dis1x1.yaml
index 6743ba3..6f0c0aa 100644
--- a/conf/dataset/segmentation/s3dis/s3dis1x1.yaml
+++ b/conf/dataset/segmentation/s3dis/s3dis1x1.yaml
@@ -1,6 +1,7 @@
 # @package dataset
 defaults:
+  - dataset_s3dis
   - segmentation/default
 _target_: torch_points3d.dataset.s3dis1x1.s3dis_data_module
 cfg:
-  fold: 5
\ No newline at end of file
+  fold: 5
diff --git a/conf/test_config.yaml b/conf/test_config.yaml
new file mode 100644
index 0000000..1295ca7
--- /dev/null
+++ b/conf/test_config.yaml
@@ -0,0 +1,6 @@
+defaults: # loads default configs
+  - base_config
+  - dataset: segmentation/s3dis/s3dis1x1
+  - training: default
+
+pretty_print: True
\ No newline at end of file
diff --git a/conf/training/default.yaml b/conf/training/default.yaml
index 46e0864..3dffcff 100644
--- a/conf/training/default.yaml
+++ b/conf/training/default.yaml
@@ -1,3 +1,5 @@
+defaults:
+  - base_trainer
 lr: 5e-5
 
 # read in dataset
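Note on the `dataset_s3dis` schema added to the defaults list above: it is registered in the ConfigStore by test.py (below), and the YAML values are merged onto it so OmegaConf can type-check them. A minimal standalone sketch of that merge behaviour, assuming only omegaconf is installed (the config class and values here are illustrative, not part of the patch):

from dataclasses import dataclass

from omegaconf import OmegaConf
from omegaconf.errors import ValidationError

@dataclass
class S3DISDataConfig:
    fold: int = 6

schema = OmegaConf.structured(S3DISDataConfig)  # typed node backed by the dataclass
merged = OmegaConf.merge(schema, {"fold": 5})   # ok: 5 matches fold's int annotation
print(merged.fold)                              # -> 5

try:
    OmegaConf.merge(schema, {"fold": "Area_5"})  # str where an int is expected
except ValidationError as err:
    print(f"rejected: {err}")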
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..0a63651
--- /dev/null
+++ b/test.py
@@ -0,0 +1,73 @@
+import hydra
+from hydra.core.global_hydra import GlobalHydra
+from omegaconf import MISSING, DictConfig, OmegaConf
+from torch_points3d.trainer import LitTrainer
+from torch_points3d.core.instantiator import HydraInstantiator, Instantiator
+from dataclasses import dataclass
+from hydra.core.config_store import ConfigStore
+from typing import List, Any, Type
+from omegaconf._utils import is_structured_config
+
+OmegaConf.register_new_resolver("get_filename", lambda x: x.split("/")[-1])
+
+
+@dataclass
+class TrainingDataConfig:
+    batch_size: int = 32
+    num_workers: int = 0
+    lr: float = MISSING
+
+# We separate the dataset "cfg" from the actual dataset object
+# so that we can pass the "cfg" into the dataset constructors as a DictConfig
+# instead of as unwrapped parameters.
+@dataclass
+class BaseDataConfig:
+    batch_size: int = 32
+    num_workers: int = 0
+    dataroot: str = "data"
+
+@dataclass
+class BaseDataset:
+    _target_: str
+    cfg: BaseDataConfig
+
+@dataclass
+class S3DISDataConfig(BaseDataConfig):
+    fold: int = 6
+
+@dataclass
+class S3DISDataset(BaseDataset):
+    cfg: S3DISDataConfig
+
+@dataclass
+class Config:
+    dataset: Any
+    training: TrainingDataConfig
+    pretty_print: bool = False
+
+def show(x):
+    print(f"type: {type(x).__name__}, value: {repr(x)}")
+
+cs = ConfigStore.instance()
+cs.store(name="base_config", node=Config)
+cs.store(group="dataset", name="dataset_s3dis", node=S3DISDataset)
+cs.store(group="training", name="base_trainer", node=TrainingDataConfig)
+
+@hydra.main(config_path="conf", config_name="test_config")
+def main(cfg: DictConfig):
+    OmegaConf.set_struct(cfg, False)  # allows access to missing keys and lets delattr remove "cfg" below
+    if cfg.get("pretty_print"):
+        print(OmegaConf.to_yaml(cfg, resolve=True))
+
+    dset = cfg.get("dataset")
+    show(dset)
+    show(dset.cfg)
+    dset_cfg = dset.cfg
+    # The cfg node loses its typing information if hydra passes it to the
+    # target class, so we detach it and pass it manually to keep the typing info.
+    delattr(dset, "cfg")
+    hydra.utils.instantiate(dset, dset_cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torch_points3d/dataset/s3dis1x1.py b/torch_points3d/dataset/s3dis1x1.py
index f429fb0..d3cb326 100644
--- a/torch_points3d/dataset/s3dis1x1.py
+++ b/torch_points3d/dataset/s3dis1x1.py
@@ -1,5 +1,5 @@
 from typing import Any, Callable, Dict, Optional, Sequence
-from omegaconf import MISSING
+from omegaconf import MISSING, DictConfig
 from dataclasses import dataclass
 
 import hydra.utils
@@ -16,24 +16,31 @@ class S3DISDataConfig(PointCloudDataConfig):
     num_workers: int = 0
     fold: int = 6
 
+def show(x):
+    print(f"type: {type(x).__name__}, value: {repr(x)}")
+
 class s3dis_data_module(PointCloudDataModule):
-    def __init__(self, cfg: S3DISDataConfig = S3DISDataConfig()) -> None:
+    def __init__(self, cfg: DictConfig) -> None:
         super().__init__(cfg)
-
-        self.ds = {
-            "train": S3DIS1x1(
-                self.cfg.dataroot,
-                test_area=self.cfg.fold,
-                train=True,
-                pre_transform=self.cfg.pre_transform,
-                transform=self.cfg.train_transform,
-            ),
-            "test": S3DIS1x1(
-                self.cfg.dataroot,
-                test_area=self.cfg.fold,
-                train=False,
-                pre_transform=self.cfg.pre_transform,
-                transform=self.cfg.train_transform,
-            ),
-        }
+        show(cfg)
+        # Deliberately assign a string to an int field: if the node still
+        # carries the S3DISDataConfig schema, OmegaConf will reject this.
+        cfg.num_workers = "aj"
+        show(cfg)
+        # print("pre_transform: ", self.cfg.pre_transform)
+        # self.ds = {
+        #     "train": S3DIS1x1(
+        #         self.cfg.dataroot,
+        #         test_area=self.cfg.fold,
+        #         train=True,
+        #         pre_transform=self.cfg.pre_transform,
+        #         transform=self.cfg.train_transform,
+        #     ),
+        #     "test": S3DIS1x1(
+        #         self.cfg.dataroot,
+        #         test_area=self.cfg.fold,
+        #         train=False,
+        #         pre_transform=self.cfg.pre_transform,
+        #         transform=self.cfg.train_transform,
+        #     ),
+        # }
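For reference, `OmegaConf.get_type` reports whether a node still carries the structured-config type that the `show()` helpers above are probing for. A small sketch, again with an illustrative config class; the round-trip through a plain container is just one way the schema can get dropped, not necessarily the exact path hydra takes when it re-wraps `cfg`:

from dataclasses import dataclass

from omegaconf import OmegaConf

@dataclass
class S3DISDataConfig:
    fold: int = 6

typed = OmegaConf.structured(S3DISDataConfig)
print(OmegaConf.get_type(typed))    # <class '__main__.S3DISDataConfig'>

# Rebuilding the node from a plain dict loses the schema; after this,
# an assignment like cfg.num_workers = "aj" would no longer be rejected.
untyped = OmegaConf.create(OmegaConf.to_container(typed))
print(OmegaConf.get_type(untyped))  # <class 'dict'>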