-import torch.nn as nn
+from typing import Any, Dict, Optional, Tuple, Union

-from omegaconf import DictConfig, OmegaConf
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

-from torch_geometric.data import Data
+from torch_points3d.core.instantiator import Instantiator
+from torch_points3d.core.config import OptimizerConfig, SchedulerConfig


-class BaseModel(nn.Module):
-    def __init__(self, opt: DictConfig):
-        super(BaseModel, self).__init__()
-        self.opt = opt
+class PointCloudBaseModel(pl.LightningModule):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: OptimizerConfig,
+        scheduler: SchedulerConfig,
+        instantiator: Instantiator,
+    ):
+        super().__init__()
+        self.model = model
+        # some optimizers/schedulers need parameters only known dynamically
+        # allow users to override the getter to instantiate them lazily
+        self.optimizer_cfg = optimizer
+        self.scheduler_cfg = scheduler
+        self.instantiator = instantiator

-    def set_input(self, data: Data):
-        """Unpack input data from the dataloader and perform necessary pre-processing steps.
-        Parameters:
-            input (dict): includes the data itself and its metadata information.
+    def configure_optimizers(self) -> Dict:
+        """Prepare optimizer and scheduler"""
+        self.optimizer = self.instantiator.optimizer(self.model, self.optimizer_cfg)
+        # compute_warmup needs the datamodule to be available when `self.num_training_steps`
+        # is called, which is why this is done here and not in __init__
+        self.scheduler_cfg.num_training_steps, self.scheduler_cfg.num_warmup_steps = self.compute_warmup(
+            num_training_steps=self.scheduler_cfg.num_training_steps,
+            num_warmup_steps=self.scheduler_cfg.num_warmup_steps,
+        )
+        rank_zero_info(f"Inferring number of training steps, set to {self.scheduler_cfg.num_training_steps}")
+        rank_zero_info(f"Inferring number of warmup steps from ratio, set to {self.scheduler_cfg.num_warmup_steps}")
+        self.scheduler = self.instantiator.scheduler(self.scheduler_cfg, self.optimizer)
+
+        return {
+            "optimizer": self.optimizer,
+            "lr_scheduler": {"scheduler": self.scheduler, "interval": "step", "frequency": 1},
+        }
+
+    @property
+    def num_training_steps(self) -> int:
+        """Total training steps inferred from datamodule and devices."""
+        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
+            dataset_size = self.trainer.limit_train_batches
+        elif isinstance(self.trainer.limit_train_batches, float):
+            # limit_train_batches is a fraction of the total number of train batches
+            dataset_size = len(self.trainer.datamodule.train_dataloader())
+            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
+        else:
+            dataset_size = len(self.trainer.datamodule.train_dataloader())
+
+        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
+        if self.trainer.tpu_cores:
+            num_devices = max(num_devices, self.trainer.tpu_cores)
+
+        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
+        max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs
+
+        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
+            return self.trainer.max_steps
+        return max_estimated_steps
+
+    def compute_warmup(self, num_training_steps: int, num_warmup_steps: Union[int, float]) -> Tuple[int, int]:
+        if num_training_steps < 0:
+            # less than 0 specifies to infer number of training steps
+            num_training_steps = self.num_training_steps
+        if isinstance(num_warmup_steps, float):
+            # convert a float value to a fraction of the training steps to use as warmup
+            num_warmup_steps *= num_training_steps
+        return num_training_steps, num_warmup_steps
+
+    def setup(self, stage: str):
+        self.configure_metrics(stage)
+
+    def configure_metrics(self, stage: str) -> Optional[Any]:
+        """
+        Override to configure metrics for train/validation/test.
+        This is called on fit start to have access to the data module
+        and initialize any data-specific metrics.
        """
-        raise NotImplementedError
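
A rough usage sketch (not part of this commit) of how a task-specific model could build on the `PointCloudBaseModel` added above: `setup` calls `configure_metrics` once the datamodule is attached to the trainer, and `configure_optimizers` handles the optimizer/scheduler wiring. The subclass name, the import path, the datamodule's `num_classes` attribute, and the `batch.y` label field are illustrative assumptions, not APIs introduced by this diff.

```python
import torch.nn.functional as F

# assumed import path for the class added in this commit
from torch_points3d.models.base_model import PointCloudBaseModel


class SegmentationModel(PointCloudBaseModel):  # hypothetical subclass
    def configure_metrics(self, stage: str) -> None:
        # invoked from `setup`, so the datamodule is already attached to the trainer;
        # a natural place to read dataset-dependent values such as the class count
        self.num_classes = self.trainer.datamodule.num_classes  # assumed datamodule attribute

    def training_step(self, batch, batch_idx):
        # `self.model` is the plain torch.nn.Module passed to __init__
        logits = self.model(batch)
        loss = F.cross_entropy(logits, batch.y)  # assumes labels live in `batch.y`
        acc = (logits.argmax(dim=-1) == batch.y).float().mean()
        self.log("train/loss", loss)
        self.log("train/acc", acc)
        return loss
```

With the scheduler config left at `num_training_steps=-1` and `num_warmup_steps=0.1`, `configure_optimizers` infers the total step count from the datamodule (e.g. 10,000 train batches × 10 epochs on one device with no gradient accumulation gives 100,000 steps) and uses 10% of it, 10,000 steps, as warmup.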