Skip to content

Commit d3d13dc

Browse files
committed
initial commit
0 parents  commit d3d13dc

File tree

8 files changed

+345
-0
lines changed

8 files changed

+345
-0
lines changed

conf/dataset/default.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package dataset
2+
_target_: lightning_transformers.core.data.TransformerDataModule
3+
cfg:
4+
# torch data-loader specific arguments
5+
batch_size: ${training.batch_size}
6+
num_workers: ${training.num_workers}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# @package dataset
2+
defaults:
3+
- /dataset/default
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# @package dataset
2+
defaults:
3+
- /dataset/default
4+
_target_: lightning_transformers.task.nlp.multiple_choice.RaceMultipleChoiceDataModule
5+
cfg:
6+
dataset_name: race
7+
dataset_config_name: 'all'
8+
padding: False
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import logging
2+
from typing import Optional, TYPE_CHECKING, Union
3+
4+
import hydra
5+
import pytorch_lightning as pl
6+
import torch
7+
from omegaconf import DictConfig
8+
9+
from lightning_transformers.core import TransformerDataModule
10+
from lightning_transformers.core.data import TokenizerDataModule
11+
12+
if TYPE_CHECKING:
13+
# avoid circular imports
14+
from lightning_transformers.core import TaskTransformer
15+
16+
17+
class Instantiator:
    """Abstract factory for the objects needed by a training run.

    Concrete subclasses (e.g. ``HydraInstantiator``) turn configuration
    objects into live models, optimizers, schedulers, data modules,
    loggers and trainers.  Every hook below raises until overridden.
    """

    def model(self, *args, **kwargs):
        """Build and return the task model."""
        raise NotImplementedError("Child class must implement method")

    def optimizer(self, *args, **kwargs):
        """Build and return an optimizer."""
        raise NotImplementedError("Child class must implement method")

    def scheduler(self, *args, **kwargs):
        """Build and return a learning-rate scheduler."""
        raise NotImplementedError("Child class must implement method")

    def data_module(self, *args, **kwargs):
        """Build and return the data module."""
        raise NotImplementedError("Child class must implement method")

    def logger(self, *args, **kwargs):
        """Build and return an experiment logger."""
        raise NotImplementedError("Child class must implement method")

    def trainer(self, *args, **kwargs):
        """Build and return the trainer."""
        raise NotImplementedError("Child class must implement method")

    def instantiate(self, *args, **kwargs):
        """Generic instantiation hook used by the factory methods above."""
        raise NotImplementedError("Child class must implement method")
39+
40+
41+
class HydraInstantiator(Instantiator):
    """``Instantiator`` implementation backed by :func:`hydra.utils.instantiate`.

    Each hook receives a ``DictConfig`` carrying a ``_target_`` key and
    returns the instantiated object.
    """

    def model(
        self,
        cfg: DictConfig,
        model_data_kwargs: Optional[DictConfig] = None,
        tokenizer: Optional[DictConfig] = None,
        pipeline_kwargs: Optional[DictConfig] = None
    ) -> "TaskTransformer":
        """Instantiate the task model.

        Args:
            cfg: model config containing ``_target_``.
            model_data_kwargs: extra kwargs provided by the data module
                (e.g. number of classes).
            tokenizer: optional tokenizer config; instantiated and forwarded.
            pipeline_kwargs: optional kwargs forwarded verbatim.
        """
        if model_data_kwargs is None:
            model_data_kwargs = {}
        # Copy into a plain dict so new keys can be added below
        # (avoids ConfigKeyError: Key 'tokenizer' is not in struct).
        model_data_kwargs = dict(model_data_kwargs)

        # use `model_data_kwargs` to pass `tokenizer` and `pipeline_kwargs`
        # as not all models might contain these parameters.
        if tokenizer:
            model_data_kwargs["tokenizer"] = self.instantiate(tokenizer)
        if pipeline_kwargs:
            model_data_kwargs["pipeline_kwargs"] = pipeline_kwargs

        return self.instantiate(cfg, instantiator=self, **model_data_kwargs)

    def optimizer(self, model: torch.nn.Module, cfg: DictConfig) -> torch.optim.Optimizer:
        """Instantiate an optimizer over ``model``'s trainable parameters.

        Weight decay is disabled for biases and LayerNorm weights — the
        standard parameter grouping for fine-tuning transformers.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": cfg.weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]
        return self.instantiate(cfg, grouped_parameters)

    def scheduler(self, cfg: DictConfig, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Instantiate a learning-rate scheduler bound to ``optimizer``."""
        return self.instantiate(cfg, optimizer=optimizer)

    def data_module(
        self,
        cfg: DictConfig,
        tokenizer: Optional[DictConfig] = None,
    ) -> Union[TransformerDataModule, TokenizerDataModule]:
        """Instantiate the data module, with an instantiated tokenizer if given."""
        if tokenizer:
            return self.instantiate(cfg, tokenizer=self.instantiate(tokenizer))
        return self.instantiate(cfg)

    def logger(self, cfg: DictConfig) -> Optional[Union[bool, logging.Logger]]:
        """Instantiate the configured logger when logging is enabled.

        Returns ``cfg.trainer.logger`` unchanged when it is a plain bool
        (Lightning accepts True/False to toggle its default logger),
        an instantiated logger otherwise, and implicitly ``None`` when
        ``cfg.log`` is unset or falsy.
        """
        # NOTE: the previous annotation (`Optional[logging.Logger]`) was wrong —
        # the bool branch below is an explicit, supported return value.
        if cfg.get("log"):
            if isinstance(cfg.trainer.logger, bool):
                return cfg.trainer.logger
            return self.instantiate(cfg.trainer.logger)

    def trainer(self, cfg: DictConfig, **kwargs) -> pl.Trainer:
        """Instantiate the trainer; ``kwargs`` override/extend config values."""
        return self.instantiate(cfg, **kwargs)

    def instantiate(self, *args, **kwargs):
        """Thin wrapper over :func:`hydra.utils.instantiate`."""
        return hydra.utils.instantiate(*args, **kwargs)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import Any, Callable, Dict, Optional
2+
3+
import pytorch_lightning as pl
4+
from torch.utils.data import DataLoader
5+
6+
from lightning_transformers.core.config import TransformerDataConfig
7+
8+
class TransformerDataModule(pl.LightningDataModule):
    """Base ``LightningDataModule`` exposing dataloaders over a split mapping.

    Subclasses are expected to populate ``self.ds`` (a mapping with
    ``"train"``/``"validation"`` and optionally ``"test"`` splits) during
    ``setup``.
    """

    def __init__(self, cfg: Optional[TransformerDataConfig] = None) -> None:
        """
        Args:
            cfg: data-loading configuration; a fresh default config is
                built when omitted.
        """
        super().__init__()
        # Build the default per instance: `cfg: ... = TransformerDataConfig()`
        # in the signature would be a mutable default argument shared by
        # every no-arg instance.
        self.cfg = cfg if cfg is not None else TransformerDataConfig()
        # Split mapping; filled in by subclasses' `setup`.
        self.ds = None

    def train_dataloader(self) -> DataLoader:
        """DataLoader over the "train" split."""
        return DataLoader(
            self.ds["train"],
            batch_size=self.batch_size,
            num_workers=self.cfg.num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self) -> DataLoader:
        """DataLoader over the "validation" split."""
        return DataLoader(
            self.ds["validation"],
            batch_size=self.batch_size,
            num_workers=self.cfg.num_workers,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self) -> Optional[DataLoader]:
        """DataLoader over the "test" split, or ``None`` when absent."""
        if "test" in self.ds:
            return DataLoader(
                self.ds["test"],
                batch_size=self.batch_size,
                num_workers=self.cfg.num_workers,
                collate_fn=self.collate_fn,
            )

    @property
    def batch_size(self) -> int:
        """Batch size taken from the config."""
        return self.cfg.batch_size

    @property
    def collate_fn(self) -> Optional[Callable]:
        # Default collation; subclasses override to supply a custom one.
        return None

    @property
    def model_data_kwargs(self) -> Dict:
        """
        Override to provide the model with additional kwargs.
        This is useful to provide the number of classes/pixels to the model or any other data specific args
        Returns: Dict of args
        """
        return {}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
2+
3+
class BaseDataset(TransformerDataModule):
    """Data-module base that wires optional transforms and filters kwargs.

    NOTE(review): this class relies on names not visible in this chunk
    (`inspect`, `LightningDataModule`, `del_attr`, `instantiate`, `T`) —
    presumably imported at the top of the file; verify before use.
    """

    # Subclasses set this to the dataset's registry name.
    NAME = ...

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Optional filtering threshold, consumed before kwargs are cleaned.
        self._threshold = kwargs.get("threshold", None)
        self.__instantiate_transform(kwargs)
        self.clean_kwargs(kwargs)
        TransformerDataModule.__init__(self, *args, **kwargs)

        self.dataset_train = None
        self.dataset_val = None
        self.dataset_test = None

        # Dataloader defaults.
        self._seed = 42
        self._num_workers = 2
        self._shuffle = True
        self._drop_last = False
        self._pin_memory = True
        self._follow_batch = []

        self._hyper_parameters = {}

    def __handle_mixin(self):
        # Placeholder hook for mixin-specific setup.
        pass

    def clean_kwargs(self, kwargs):
        """Drop kwargs that ``LightningDataModule.__init__`` does not accept."""
        # `inspect.getargspec` was removed in Python 3.11; `getfullargspec`
        # exposes the same `.args` list.
        LightningDataModuleArgs = inspect.getfullargspec(LightningDataModule.__init__).args
        keys = list(kwargs.keys())
        for key in keys:
            if key not in LightningDataModuleArgs:
                del_attr(kwargs, key)

    @property
    def config(self):
        # Minimal config stub; subclasses extend as needed.
        return {"dataset_config": {}}

    def __instantiate_transform(self, kwargs):
        """Build ``T.Compose`` pipelines from any ``*transform*`` kwargs and
        remove those entries from ``kwargs``."""
        self._pre_transform = None
        self._transform = None
        self._train_transform = None
        self._val_transform = None
        self._test_transform = None

        # Iterate over a snapshot of the keys: entries are deleted below.
        for k in [k for k in kwargs]:
            if "transform" in k and kwargs.get(k) is not None:
                transforms = []
                for t in kwargs.get(k):
                    # Honor an optional per-transform `activate` switch:
                    # skip deactivated transforms and strip the key before
                    # instantiation.
                    if t.get("activate") is not None:
                        if t.activate is False:
                            continue
                        del t["activate"]
                    transforms.append(instantiate(t))
                transform = T.Compose(transforms)
                setattr(self, f"_{k}", transform)
                del kwargs[k]

    @property
    def num_features(self):
        # Subclasses report the input feature count.
        pass

    @property
    def num_classes(self):
        # Subclasses report the class count.
        pass

    @property
    def hyper_parameters(self):
        """Model-facing hyper-parameters derived from the dataset."""
        return {"num_features": self.num_features, "num_classes": self.num_classes}

    def prepare_data(self):
        # No download/preparation step by default.
        pass

lightning_transformers/trainer.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import Any, Optional
2+
3+
import hydra
4+
from omegaconf import DictConfig, OmegaConf
5+
from pytorch_lightning import LightningDataModule
6+
from pytorch_lightning.utilities.distributed import rank_zero_info
7+
8+
from lightning_transformers.core import TaskTransformer, TransformerDataModule
9+
from lightning_transformers.core.config import TaskConfig, TrainerConfig, TransformerDataConfig
10+
from lightning_transformers.core.instantiator import HydraInstantiator, Instantiator
11+
from lightning_transformers.core.nlp.config import HFTokenizerConfig
12+
from lightning_transformers.core.utils import set_ignore_warnings
13+
14+
15+
def run(
    instantiator: Instantiator,
    ignore_warnings: bool = True,
    run_test_after_fit: bool = True,
    dataset: Optional[TransformerDataConfig] = None,
    task: Optional[TaskConfig] = None,
    trainer: Optional[TrainerConfig] = None,
    tokenizer: Optional[HFTokenizerConfig] = None,
    logger: Optional[Any] = None,
) -> None:
    """Build every training component from config, then fit (and test).

    Args:
        instantiator: factory used to build all components.
        ignore_warnings: silence known noisy warnings before training.
        run_test_after_fit: run ``trainer.test`` once fitting finishes.
        dataset: data-module config; a fresh default is built when omitted.
        task: model/task config; a fresh default is built when omitted.
        trainer: trainer config; a fresh default is built when omitted.
        tokenizer: optional tokenizer config forwarded to the data module.
        logger: optional logger handed to the trainer.
    """
    # Build default configs per call: config *instances* in the signature
    # would be mutable default arguments shared across calls.
    if dataset is None:
        dataset = TransformerDataConfig()
    if task is None:
        task = TaskConfig()
    if trainer is None:
        trainer = TrainerConfig()

    if ignore_warnings:
        set_ignore_warnings()

    data_module_kwargs = {}
    if tokenizer is not None:
        data_module_kwargs["tokenizer"] = tokenizer

    data_module: TransformerDataModule = instantiator.data_module(dataset, **data_module_kwargs)
    if data_module is None:
        raise ValueError("No dataset found. Hydra hint: did you set `dataset=...`?")
    if not isinstance(data_module, LightningDataModule):
        raise ValueError(
            "The instantiator did not return a DataModule instance."
            " Hydra hint: is `dataset._target_` defined?"
        )
    data_module.setup("fit")

    model: TaskTransformer = instantiator.model(task, model_data_kwargs=getattr(data_module, "model_data_kwargs", None))
    # Name the built Trainer distinctly so it does not shadow the `trainer`
    # config parameter.
    pl_trainer = instantiator.trainer(
        trainer,
        logger=logger,
    )

    pl_trainer.fit(model, datamodule=data_module)
    if run_test_after_fit:
        pl_trainer.test(model, datamodule=data_module)
51+
52+
53+
def main(cfg: DictConfig) -> None:
    """Log the resolved config, build an instantiator, and delegate to :func:`run`."""
    rank_zero_info(OmegaConf.to_yaml(cfg))
    instantiator = HydraInstantiator()
    logger = instantiator.logger(cfg)
    run(
        instantiator,
        # Supply defaults matching `run`'s signature: a bare `cfg.get(key)`
        # would pass `None` for absent keys and silently override them.
        ignore_warnings=cfg.get("ignore_warnings", True),
        run_test_after_fit=cfg.get("training").get("run_test_after_fit", True),
        dataset=cfg.get("dataset"),
        tokenizer=cfg.get("tokenizer"),
        task=cfg.get("task"),
        trainer=cfg.get("trainer"),
        logger=logger,
    )
67+
68+
69+
@hydra.main(config_path="../../conf", config_name="config")
def hydra_entry(cfg: DictConfig) -> None:
    """Hydra CLI entry point; delegates straight to :func:`main`."""
    main(cfg)


if __name__ == "__main__":
    hydra_entry()

train.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import hydra
2+
from hydra.core.global_hydra import GlobalHydra
3+
from omegaconf import OmegaConf
4+
from torch_points3d.trainer import Trainer
5+
6+
def _basename(path):
    """Resolver body: final component of a '/'-separated path."""
    return path.split('/')[-1]


# Make `${get_filename:...}` available for interpolation in configs.
OmegaConf.register_new_resolver("get_filename", _basename)


@hydra.main(config_path="conf", config_name="config")
def main(cfg):
    """Entry point: build a Trainer from the Hydra config and train it."""
    # Disabling struct mode lets getattr/hasattr work on missing keys.
    OmegaConf.set_struct(cfg, False)
    if cfg.pretty_print:
        print(OmegaConf.to_yaml(cfg))

    trainer = Trainer(cfg)
    trainer.train()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)