Skip to content

Commit 7deb404

Browse files
committed
update
1 parent d5f96b7 commit 7deb404

File tree

4 files changed

+26
-10
lines changed

4 files changed

+26
-10
lines changed

conf/dataset/default.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
# @package dataset
2-
# cfg:
3-
# torch data-loader specific arguments
42
cfg:
5-
num_classes:
63
feature_dimension:
74
batch_size: ${training.batch_size}
85
num_workers: ${training.num_workers}
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
# @package dataset
22
defaults:
3-
- /dataset/default
3+
- /dataset/default
4+
5+
cfg:
6+
num_classes:

torch_points3d/models/base_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
from typing import Dict, Union
2+
3+
import torch
14
import torch.nn as nn
5+
from torch_geometric.data import Data
26

37
from torch_points3d.core.instantiator import Instantiator
48

@@ -9,8 +13,11 @@ def __init__(self, instantiator: Instantiator):
913

1014
self.instantiator = instantiator
1115

12-
def set_input(self, data):
16+
def set_input(self, data: Data) -> None:
1317
raise (NotImplementedError("set_input needs to be defined!"))
1418

15-
def forward(self):
19+
def forward(self) -> Union[torch.Tensor, None]:
1620
raise (NotImplementedError("forward needs to be defined!"))
21+
22+
def get_losses(self) -> Union[torch.Tensor, None]:
23+
raise (NotImplementedError("get_losses needs to be defined!"))

torch_points3d/models/segmentation/base_model.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from omegaconf import DictConfig
2+
from typing import Dict, Union
23

4+
import torch
35
import torch.nn as nn
46
import torch.nn.functional as F
7+
from torch_geometric.data import Data
58

69
from torch_points3d.core.instantiator import Instantiator
710
from torch_points3d.models.base_model import PointCloudBaseModel
@@ -17,19 +20,25 @@ def __init__(self, instantiator: Instantiator, num_classes: int, backbone: DictC
1720

1821
self.head = nn.Sequential(nn.Linear(self.backbone.output_nc, num_classes))
1922

20-
def set_input(self, data):
23+
def set_input(self, data: Data) -> None:
2124
self.batch_idx = data.batch.squeeze()
2225
self.input = data
2326
if data.y is not None:
2427
self.labels = data.y
2528
else:
2629
self.labels = None
2730

28-
def forward(self):
31+
def forward(self) -> Union[torch.Tensor, None]:
2932
features = self.backbone(self.input).x
3033
logits = self.head(features)
3134
self.output = F.log_softmax(logits, dim=-1)
3235

36+
return self.get_losses()
37+
38+
def get_losses(self) -> Union[torch.Tensor, None]:
3339
# only compute loss if loss is defined and the dset has labels
34-
if self.labels is not None and self.criterion is not None:
35-
return self.criterion(self.output, self.labels)
40+
if self.labels is None or self.criterion is None:
41+
return
42+
43+
self.loss = self.criterion(self.output, self.labels)
44+
return self.loss

0 commit comments

Comments
 (0)