
Commit 826ca5c

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1618ffa commit 826ca5c

3 files changed: 7 additions & 19 deletions


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 [tool.black]
-line-length = 120
+line-length = 120

(The removed and added lines are textually identical, so the hook changed only invisible whitespace, e.g. trailing whitespace or the final newline.)
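For context, the 120-character limit configured here is what allows the one-line rewrites in the two Python files below. A minimal sketch (not part of this repository) that previews the same reflow through Black's Python API; black.format_str and black.Mode are assumed to be available from an installed black package:

import black

# The split call mirrors the pre-commit change in instantiator.py below.
split_call = (
    "model_data_kwargs = dict(\n"
    "    model_data_kwargs\n"
    ")  # avoid ConfigKeyError\n"
)

# With line_length=120 the call fits on one line, so Black joins it.
print(black.format_str(split_call, mode=black.Mode(line_length=120)))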

torch_points3d/core/instantiator.py

Lines changed: 5 additions & 15 deletions
@@ -47,9 +47,7 @@ def model(
     ) -> "TaskTransformer":
         if model_data_kwargs is None:
             model_data_kwargs = {}
-        model_data_kwargs = dict(
-            model_data_kwargs
-        )  # avoid ConfigKeyError: Key 'tokenizer' is not in struct`
+        model_data_kwargs = dict(model_data_kwargs)  # avoid ConfigKeyError: Key 'tokenizer' is not in struct`

         # use `model_data_kwargs` to pass `tokenizer` and `pipeline_kwargs`
         # as not all models might contain these parameters.
@@ -60,33 +58,25 @@ def model(

         return self.instantiate(cfg, instantiator=self, **model_data_kwargs)

-    def optimizer(
-        self, model: torch.nn.Module, cfg: DictConfig
-    ) -> torch.optim.Optimizer:
+    def optimizer(self, model: torch.nn.Module, cfg: DictConfig) -> torch.optim.Optimizer:
         no_decay = ["bias", "LayerNorm.weight"]
         grouped_parameters = [
             {
                 "params": [
-                    p
-                    for n, p in model.named_parameters()
-                    if not any(nd in n for nd in no_decay) and p.requires_grad
+                    p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
                 ],
                 "weight_decay": cfg.weight_decay,
             },
             {
                 "params": [
-                    p
-                    for n, p in model.named_parameters()
-                    if any(nd in n for nd in no_decay) and p.requires_grad
+                    p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
                 ],
                 "weight_decay": 0.0,
             },
         ]
         return self.instantiate(cfg, grouped_parameters)

-    def scheduler(
-        self, cfg: DictConfig, optimizer: torch.optim.Optimizer
-    ) -> torch.optim.lr_scheduler._LRScheduler:
+    def scheduler(self, cfg: DictConfig, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
         return self.instantiate(cfg, optimizer=optimizer)

     def data_module(
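The reflowed optimizer method keeps the usual pattern of excluding biases and LayerNorm weights from weight decay. A minimal, self-contained sketch of that grouping pattern; the TinyBlock module, the 0.01 decay value, the learning rate, and the AdamW choice are illustrative assumptions, not taken from this repository, which instantiates the optimizer from its Hydra config instead:

import torch
from torch import nn


class TinyBlock(nn.Module):
    # Toy module for illustration only. The attribute name "LayerNorm" matters:
    # the no_decay filter matches the substring "LayerNorm.weight" in parameter names.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)


model = TinyBlock()
no_decay = ["bias", "LayerNorm.weight"]
grouped_parameters = [
    {
        # regular weights: decayed
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
        "weight_decay": 0.01,  # illustrative value; the instantiator reads cfg.weight_decay
    },
    {
        # biases and LayerNorm weights: conventionally not decayed
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(grouped_parameters, lr=1e-3)  # illustrative; the repo calls self.instantiate(cfg, grouped_parameters)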

train.py

Lines changed: 1 addition & 3 deletions
@@ -8,9 +8,7 @@


 @hydra.main(config_path="conf", config_name="config")
 def main(cfg):
-    OmegaConf.set_struct(
-        cfg, False
-    )  # This allows getattr and hasattr methods to function correctly
+    OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
     if cfg.pretty_print:
         print(OmegaConf.to_yaml(cfg))

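OmegaConf.set_struct(cfg, False) turns off struct mode on the Hydra config; as the inline comment notes, this keeps getattr and hasattr working as expected and, in particular, lets code add keys that the config schema does not declare (the same restriction behind the ConfigKeyError worked around in instantiator.py, where the DictConfig is also converted to a plain dict). A minimal sketch with a hand-built config whose keys are made up for illustration:

from omegaconf import OmegaConf

# Hand-built config purely for illustration; in this repo Hydra builds cfg from conf/config.yaml.
cfg = OmegaConf.create({"model": {"name": "example"}})

OmegaConf.set_struct(cfg, True)  # Hydra configs are typically delivered in struct mode
# cfg.model.extra_kwarg = 1      # would raise: struct mode rejects keys not declared in the config

OmegaConf.set_struct(cfg, False)  # what main() does above
cfg.model.extra_kwarg = 1  # now allowed
print(OmegaConf.to_yaml(cfg))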
0 commit comments
