1+ import torch
2+ import torch .nn as nn
3+ import torch .nn .functional as F
4+
5+ class Pointnet (nn .Module ):
6+
7+ def __init__ (self , in_channels ,
8+ out_channels ,
9+ hidden_dim ):
10+ super ().__init__ ()
11+
12+ self .fc_in = nn .Conv1d (in_channels , 2 * hidden_dim , 1 )
13+ self .fc_0 = nn .Conv1d (2 * hidden_dim , hidden_dim , 1 )
14+ self .fc_1 = nn .Conv1d (2 * hidden_dim , hidden_dim , 1 )
15+ self .fc_2 = nn .Conv1d (2 * hidden_dim , hidden_dim , 1 )
16+ self .fc_3 = nn .Conv1d (2 * hidden_dim , hidden_dim , 1 )
17+ self .fc_out = nn .Linear (hidden_dim , out_channels , 1 )
18+
19+ self .activation = nn .ReLU ()
20+
21+ def forward (self , x ):
22+
23+ x = self .fc_in (x )
24+
25+ x = self .fc_0 (self .activation (x ))
26+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
27+ x = torch .cat ([x , x_pool ], dim = 1 )
28+
29+ x = self .fc_1 (self .activation (x ))
30+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
31+ x = torch .cat ([x , x_pool ], dim = 1 )
32+
33+ x = self .fc_2 (self .activation (x ))
34+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
35+ x = torch .cat ([x , x_pool ], dim = 1 )
36+
37+ x = self .fc_3 (self .activation (x ))
38+
39+ x = torch .max (x , dim = 2 )[0 ]
40+
41+ x = self .fc_out (x )
42+
43+ return x
44+
45+
46+ class ResidualBlock (nn .Module ):
47+
48+ def __init__ (self , in_channels , out_channels , hidden_dim ):
49+ super ().__init__ ()
50+
51+ # Submodules
52+ self .fc_0 = nn .Conv1d (in_channels , hidden_dim , 1 )
53+ self .fc_1 = nn .Conv1d (hidden_dim , out_channels , 1 )
54+ self .activation = nn .ReLU ()
55+
56+ if in_channels != out_channels :
57+ self .shortcut = nn .Conv1d (in_channels , out_channels ,1 )
58+ else :
59+ self .shortcut = nn .Identity ()
60+
61+ nn .init .zeros_ (self .fc_1 .weight )
62+
63+ def forward (self , x ):
64+ x_short = self .shortcut (x )
65+ x = self .fc_0 (x )
66+ x = self .fc_1 (self .activation (x ))
67+ x = self .activation (x + x_short )
68+ return x
69+
70+
71+
72+ class ResidualPointnet (nn .Module ):
73+ ''' PointNet-based encoder network with ResNet blocks.
74+ Args:
75+ c_dim (int): dimension of latent code c
76+ dim (int): input points dimension
77+ hidden_dim (int): hidden dimension of the network
78+ '''
79+
80+ def __init__ (self , in_channels , out_channels , hidden_dim ):
81+ super ().__init__ ()
82+
83+ self .fc_in = nn .Conv1d (in_channels , 2 * hidden_dim , 1 )
84+ self .block_0 = ResidualBlock (2 * hidden_dim , hidden_dim , hidden_dim )
85+ self .block_1 = ResidualBlock (2 * hidden_dim , hidden_dim , hidden_dim )
86+ self .block_2 = ResidualBlock (2 * hidden_dim , hidden_dim , hidden_dim )
87+ self .block_3 = ResidualBlock (2 * hidden_dim , hidden_dim , hidden_dim )
88+ self .block_4 = ResidualBlock (2 * hidden_dim , hidden_dim , hidden_dim )
89+ self .fc_out = nn .Linear (hidden_dim , out_channels )
90+
91+
92+ def forward (self , x ):
93+
94+ x = self .fc_in (x )
95+
96+ x = self .block_0 (x )
97+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
98+ x = torch .cat ([x , x_pool ], dim = 1 )
99+
100+ x = self .block_1 (x )
101+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
102+ x = torch .cat ([x , x_pool ], dim = 1 )
103+
104+ x = self .block_2 (x )
105+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
106+ x = torch .cat ([x , x_pool ], dim = 1 )
107+
108+ x = self .block_3 (x )
109+ x_pool = torch .max (x , dim = 2 , keepdim = True )[0 ].expand_as (x )
110+ x = torch .cat ([x , x_pool ], dim = 1 )
111+
112+ x = self .block_4 (x )
113+
114+ x = torch .max (x , dim = 2 )[0 ]
115+
116+ x = self .fc_out (x )
117+
118+ return x
0 commit comments