Skip to content

Commit 8935496

Browse files
committed
networks - add pointnet segmentation
1 parent ef3c512 commit 8935496

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

lightconvpoint/networks/pointnet.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,21 @@ class Pointnet(nn.Module):
66

77
def __init__(self, in_channels,
88
out_channels,
9-
hidden_dim):
9+
hidden_dim, segmentation=False):
1010
super().__init__()
1111

1212
self.fc_in = nn.Conv1d(in_channels, 2*hidden_dim, 1)
1313
self.fc_0 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
1414
self.fc_1 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
1515
self.fc_2 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
1616
self.fc_3 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
17-
self.fc_out = nn.Linear(hidden_dim, out_channels, 1)
17+
18+
self.segmentation=segmentation
19+
20+
if segmentation:
21+
self.fc_out = nn.Conv1d(2*hidden_dim, out_channels, 1)
22+
else:
23+
self.fc_out = nn.Linear(hidden_dim, out_channels)
1824

1925
self.activation = nn.ReLU()
2026

@@ -36,8 +42,12 @@ def forward(self, x):
3642

3743
x = self.fc_3(self.activation(x))
3844

39-
x = torch.max(x, dim=2)[0]
40-
45+
if self.segmentation:
46+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
47+
x = torch.cat([x, x_pool], dim=1)
48+
else:
49+
x = torch.max(x, dim=2)[0]
50+
4151
x = self.fc_out(x)
4252

4353
return x
@@ -77,7 +87,7 @@ class ResidualPointnet(nn.Module):
7787
hidden_dim (int): hidden dimension of the network
7888
'''
7989

80-
def __init__(self, in_channels, out_channels, hidden_dim):
90+
def __init__(self, in_channels, out_channels, hidden_dim, segmentation=False):
8191
super().__init__()
8292

8393
self.fc_in = nn.Conv1d(in_channels, 2*hidden_dim, 1)
@@ -86,7 +96,12 @@ def __init__(self, in_channels, out_channels, hidden_dim):
8696
self.block_2 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
8797
self.block_3 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
8898
self.block_4 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
89-
self.fc_out = nn.Linear(hidden_dim, out_channels)
99+
100+
self.segmentation = segmentation
101+
if self.segmentation:
102+
self.fc_out = nn.Conv1d(2*hidden_dim, out_channels, 1)
103+
else:
104+
self.fc_out = nn.Linear(hidden_dim, out_channels)
90105

91106

92107
def forward(self, x):
@@ -111,7 +126,11 @@ def forward(self, x):
111126

112127
x = self.block_4(x)
113128

114-
x = torch.max(x, dim=2)[0]
129+
if self.segmentation:
130+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
131+
x = torch.cat([x, x_pool], dim=1)
132+
else:
133+
x = torch.max(x, dim=2)[0]
115134

116135
x = self.fc_out(x)
117136

0 commit comments

Comments
 (0)