Skip to content

Commit 6946836

Browse files
committed
rename model=>models
1 parent 4e5659f commit 6946836

File tree

8 files changed

+87
-105
lines changed

8 files changed

+87
-105
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# @package dataset
22
defaults:
33
- segmentation/default
4-
_target_: torch_points3d.dataset.s3dis1x1.s3dis_data_module
4+
_target_: torch_points3d.datasets.s3dis1x1.s3dis_data_module
55
cfg:
66
fold: 5

conf/model/default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# By default we turn off recursive instantiation, allowing the user to instantiate themselves at the appropriate times.
22
_recursive_: false
33

4-
_target_: torch_points3d.model.base_model.PointCloudBaseModel
4+
_target_: torch_points3d.models.base_model.PointCloudBaseModel
55
optimizer: ${optimizer}
66
scheduler: ${scheduler}

torch_points3d/core/instantiator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import torch
77
from omegaconf import DictConfig
88

9-
from torch_points3d.dataset.base_dataset import PointCloudDataModule
9+
from torch_points3d.datasets.base_dataset import PointCloudDataModule
1010

1111
if TYPE_CHECKING:
1212
# avoid circular imports
13-
from torch_points3d.model.base_model import PointCloudBaseModel
13+
from torch_points3d.models.base_model import PointCloudBaseModel
1414

1515

1616
class Instantiator:
File renamed without changes.

torch_points3d/dataset/s3dis1x1.py renamed to torch_points3d/datasets/s3dis1x1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import hydra.utils
66
import torch_geometric.transforms as T
77
from torch.utils.data import DataLoader
8-
from torch_points3d.dataset.base_dataset import PointCloudDataModule, PointCloudDataConfig
8+
from torch_points3d.datasets.base_dataset import PointCloudDataModule, PointCloudDataConfig
99

1010
from torch_geometric.datasets import S3DIS as S3DIS1x1
1111

torch_points3d/model/base_model.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

torch_points3d/models/base_model.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,86 @@
1-
import torch.nn as nn
1+
from typing import Any, Dict, Optional, Tuple, Union
22

3-
from omegaconf import DictConfig, OmegaConf
3+
import pytorch_lightning as pl
4+
import torch
5+
from pytorch_lightning.utilities import rank_zero_info
6+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
47

5-
from torch_geometric.data import Data
8+
from torch_points3d.core.instantiator import Instantiator
9+
from torch_points3d.core.config import OptimizerConfig, SchedulerConfig
610

711

8-
class BaseModel(nn.Module):
9-
def __init__(self, opt: DictConfig):
10-
super(BaseModel, self).__init__()
11-
self.opt = opt
12+
class PointCloudBaseModel(pl.LightningModule):
13+
def __init__(
14+
self,
15+
model: torch.nn.Module,
16+
optimizer: OptimizerConfig,
17+
scheduler: SchedulerConfig,
18+
instantiator: Instantiator,
19+
):
20+
super().__init__()
21+
self.model = model
22+
# some optimizers/schedulers need parameters only known dynamically
23+
# allow users to override the getter to instantiate them lazily
24+
self.optimizer_cfg = optimizer
25+
self.scheduler_cfg = scheduler
26+
self.instantiator = instantiator
1227

13-
def set_input(self, data: Data):
14-
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
15-
Parameters:
16-
input (dict): includes the data itself and its metadata information.
28+
def configure_optimizers(self) -> Dict:
29+
"""Prepare optimizer and scheduler"""
30+
self.optimizer = self.instantiator.optimizer(self.model, self.optimizer_cfg)
31+
# compute_warmup needs the datamodule to be available when `self.num_training_steps`
32+
# is called that is why this is done here and not in the __init__
33+
self.scheduler_cfg.num_training_steps, self.scheduler_cfg.num_warmup_steps = self.compute_warmup(
34+
num_training_steps=self.scheduler_cfg.num_training_steps,
35+
num_warmup_steps=self.scheduler_cfg.num_warmup_steps,
36+
)
37+
rank_zero_info(f"Inferring number of training steps, set to {self.scheduler_cfg.num_training_steps}")
38+
rank_zero_info(f"Inferring number of warmup steps from ratio, set to {self.scheduler_cfg.num_warmup_steps}")
39+
self.scheduler = self.instantiator.scheduler(self.scheduler_cfg, self.optimizer)
40+
41+
return {
42+
"optimizer": self.optimizer,
43+
"lr_scheduler": {"scheduler": self.scheduler, "interval": "step", "frequency": 1},
44+
}
45+
46+
@property
47+
def num_training_steps(self) -> int:
48+
"""Total training steps inferred from datamodule and devices."""
49+
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
50+
dataset_size = self.trainer.limit_train_batches
51+
elif isinstance(self.trainer.limit_train_batches, float):
52+
# limit_train_batches is a percentage of batches
53+
dataset_size = len(self.trainer.datamodule.train_dataloader())
54+
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
55+
else:
56+
dataset_size = len(self.trainer.datamodule.train_dataloader())
57+
58+
num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
59+
if self.trainer.tpu_cores:
60+
num_devices = max(num_devices, self.trainer.tpu_cores)
61+
62+
effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
63+
max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs
64+
65+
if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
66+
return self.trainer.max_steps
67+
return max_estimated_steps
68+
69+
def compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> Tuple[int, int]:
70+
if num_training_steps < 0:
71+
# less than 0 specifies to infer number of training steps
72+
num_training_steps = self.num_training_steps
73+
if isinstance(num_warmup_steps, float):
74+
# Convert float values to percentage of training steps to use as warmup
75+
num_warmup_steps *= num_training_steps
76+
return num_training_steps, num_warmup_steps
77+
78+
def setup(self, stage: str):
79+
self.configure_metrics(stage)
80+
81+
def configure_metrics(self, stage: str) -> Optional[Any]:
82+
"""
83+
Override to configure metrics for train/validation/test.
84+
This is called on fit start to have access to the data module,
85+
and initialize any data specific metrics.
1786
"""
18-
raise NotImplementedError

torch_points3d/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
# from hydra.utils.instantiate as hydra_instantiate
88

9-
from torch_points3d.model.base_model import PointCloudBaseModel
10-
from torch_points3d.dataset.base_dataset import PointCloudDataModule, PointCloudDataConfig, PointCloudDataModule
9+
from torch_points3d.models.base_model import PointCloudBaseModel
10+
from torch_points3d.datasets.base_dataset import PointCloudDataModule, PointCloudDataConfig, PointCloudDataModule
1111
from torch_points3d.core.instantiator import HydraInstantiator, Instantiator
1212
from torch_points3d.core.config import TaskConfig, TrainerConfig
1313

0 commit comments

Comments
 (0)