Skip to content

Commit d5f96b7

Browse files
committed
add criterion
1 parent aec17f1 commit d5f96b7

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

conf/model/segmentation/default.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ model:
66
_recursive_: false
77
_target_: torch_points3d.models.segmentation.base_model.SegmentationBaseModel
88
num_classes: ${dataset.cfg.num_classes}
9+
criterion:
10+
_target_: torch.nn.NLLLoss
911

1012
backbone:
1113
input_nc: ${dataset.cfg.feature_dimension}

torch_points3d/models/segmentation/base_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99

1010
class SegmentationBaseModel(PointCloudBaseModel):
11-
def __init__(self, instantiator: Instantiator, num_classes: int, backbone: DictConfig):
11+
def __init__(self, instantiator: Instantiator, num_classes: int, backbone: DictConfig, criterion: DictConfig):
1212
super().__init__(instantiator)
1313

1414
print(backbone)
1515
self.backbone = self.instantiator.backbone(backbone)
16+
self.criterion = self.instantiator.instantiate(criterion)
1617

1718
self.head = nn.Sequential(nn.Linear(self.backbone.output_nc, num_classes))
1819

0 commit comments

Comments
 (0)