@@ -6,15 +6,21 @@ class Pointnet(nn.Module):
66
77 def __init__ (self , in_channels ,
88 out_channels ,
9- hidden_dim ):
9+ hidden_dim , segmentation = False ):
1010 super ().__init__ ()
1111
1212 self .fc_in = nn .Conv1d (in_channels , 2 * hidden_dim , 1 )
1313 self .fc_0 = nn .Conv1d (2 * hidden_dim , hidden_dim , 1 )
1414 self .fc_1 = nn .Conv1d (2 * hidden_dim , hidden_dim , 1 )
1515 self .fc_2 = nn .Conv1d (2 * hidden_dim , hidden_dim , 1 )
1616 self .fc_3 = nn .Conv1d (2 * hidden_dim , hidden_dim , 1 )
17- self .fc_out = nn .Linear (hidden_dim , out_channels , 1 )
17+
18+ self .segmentation = segmentation
19+
20+ if segmentation :
21+ self .fc_out = nn .Conv1d (2 * hidden_dim , out_channels , 1 )
22+ else :
23+ self .fc_out = nn .Linear (hidden_dim , out_channels )
1824
1925 self .activation = nn .ReLU ()
2026
@@ -36,8 +42,12 @@ def forward(self, x):
3642
3743 x = self .fc_3 (self .activation (x ))
3844
39- x = torch .max (x , dim = 2 )[0 ]
40-
45+ if self .segmentation :
46+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
47+ x = torch .cat ([x , x_pool ], dim = 1 )
48+ else :
49+ x = torch .max (x , dim = 2 )[0 ]
50+
4151 x = self .fc_out (x )
4252
4353 return x
@@ -77,7 +87,7 @@ class ResidualPointnet(nn.Module):
7787 hidden_dim (int): hidden dimension of the network
7888 '''
7989
80- def __init__ (self , in_channels , out_channels , hidden_dim ):
90+ def __init__ (self , in_channels , out_channels , hidden_dim , segmentation = False ):
8191 super ().__init__ ()
8292
8393 self .fc_in = nn .Conv1d (in_channels , 2 * hidden_dim , 1 )
@@ -86,7 +96,12 @@ def __init__(self, in_channels, out_channels, hidden_dim):
8696 self .block_2 = ResidualBlock (2 * hidden_dim , hidden_dim , hidden_dim )
8797 self .block_3 = ResidualBlock (2 * hidden_dim , hidden_dim , hidden_dim )
8898 self .block_4 = ResidualBlock (2 * hidden_dim , hidden_dim , hidden_dim )
89- self .fc_out = nn .Linear (hidden_dim , out_channels )
99+
100+ self .segmentation = segmentation
101+ if self .segmentation :
102+ self .fc_out = nn .Conv1d (2 * hidden_dim , out_channels , 1 )
103+ else :
104+ self .fc_out = nn .Linear (hidden_dim , out_channels )
90105
91106
92107 def forward (self , x ):
@@ -111,7 +126,11 @@ def forward(self, x):
111126
112127 x = self .block_4 (x )
113128
114- x = torch .max (x , dim = 2 )[0 ]
129+ if self .segmentation :
130+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
131+ x = torch .cat ([x , x_pool ], dim = 1 )
132+ else :
133+ x = torch .max (x , dim = 2 )[0 ]
115134
116135 x = self .fc_out (x )
117136
0 commit comments