Skip to content

Commit ef3c512

Browse files
committed
Adding Pointnet classification networks
1 parent 139619a commit ef3c512

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
class Pointnet(nn.Module):
6+
7+
def __init__(self, in_channels,
8+
out_channels,
9+
hidden_dim):
10+
super().__init__()
11+
12+
self.fc_in = nn.Conv1d(in_channels, 2*hidden_dim, 1)
13+
self.fc_0 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
14+
self.fc_1 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
15+
self.fc_2 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
16+
self.fc_3 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
17+
self.fc_out = nn.Linear(hidden_dim, out_channels, 1)
18+
19+
self.activation = nn.ReLU()
20+
21+
def forward(self, x):
22+
23+
x = self.fc_in(x)
24+
25+
x = self.fc_0(self.activation(x))
26+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
27+
x = torch.cat([x, x_pool], dim=1)
28+
29+
x = self.fc_1(self.activation(x))
30+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
31+
x = torch.cat([x, x_pool], dim=1)
32+
33+
x = self.fc_2(self.activation(x))
34+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
35+
x = torch.cat([x, x_pool], dim=1)
36+
37+
x = self.fc_3(self.activation(x))
38+
39+
x = torch.max(x, dim=2)[0]
40+
41+
x = self.fc_out(x)
42+
43+
return x
44+
45+
46+
class ResidualBlock(nn.Module):
47+
48+
def __init__(self, in_channels, out_channels, hidden_dim):
49+
super().__init__()
50+
51+
# Submodules
52+
self.fc_0 = nn.Conv1d(in_channels, hidden_dim, 1)
53+
self.fc_1 = nn.Conv1d(hidden_dim, out_channels, 1)
54+
self.activation = nn.ReLU()
55+
56+
if in_channels != out_channels:
57+
self.shortcut = nn.Conv1d(in_channels, out_channels,1)
58+
else:
59+
self.shortcut = nn.Identity()
60+
61+
nn.init.zeros_(self.fc_1.weight)
62+
63+
def forward(self, x):
64+
x_short = self.shortcut(x)
65+
x = self.fc_0(x)
66+
x = self.fc_1(self.activation(x))
67+
x = self.activation(x + x_short)
68+
return x
69+
70+
71+
72+
class ResidualPointnet(nn.Module):
73+
''' PointNet-based encoder network with ResNet blocks.
74+
Args:
75+
c_dim (int): dimension of latent code c
76+
dim (int): input points dimension
77+
hidden_dim (int): hidden dimension of the network
78+
'''
79+
80+
def __init__(self, in_channels, out_channels, hidden_dim):
81+
super().__init__()
82+
83+
self.fc_in = nn.Conv1d(in_channels, 2*hidden_dim, 1)
84+
self.block_0 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
85+
self.block_1 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
86+
self.block_2 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
87+
self.block_3 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
88+
self.block_4 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
89+
self.fc_out = nn.Linear(hidden_dim, out_channels)
90+
91+
92+
def forward(self, x):
93+
94+
x = self.fc_in(x)
95+
96+
x = self.block_0(x)
97+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
98+
x = torch.cat([x, x_pool], dim=1)
99+
100+
x = self.block_1(x)
101+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
102+
x = torch.cat([x, x_pool], dim=1)
103+
104+
x = self.block_2(x)
105+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
106+
x = torch.cat([x, x_pool], dim=1)
107+
108+
x = self.block_3(x)
109+
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
110+
x = torch.cat([x, x_pool], dim=1)
111+
112+
x = self.block_4(x)
113+
114+
x = torch.max(x, dim=2)[0]
115+
116+
x = self.fc_out(x)
117+
118+
return x

0 commit comments

Comments
 (0)