Skip to content

Commit abc4db9

Browse files
knn interpolate dense
1 parent 4cf6eff commit abc4db9

File tree

6 files changed

+59
-8
lines changed

6 files changed

+59
-8
lines changed

.devcontainer/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ RUN apt-get update \
3030

3131
RUN pip3 install -U pip
3232
RUN pip3 install torch numpy scikit-learn flake8 setuptools
33+
RUN pip3 install torch_cluster torch_sparse torch_scatter torch_geometric

.devcontainer/devcontainer.json

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,23 @@
44
"name": "Python 3",
55
"context": "..",
66
"dockerFile": "Dockerfile",
7-
87
// Set *default* container specific settings.json values on container create.
9-
"settings": {
8+
"settings": {
109
"terminal.integrated.shell.linux": "/bin/bash",
1110
"python.pythonPath": "/usr/local/bin/python",
1211
"python.linting.enabled": true,
1312
"python.linting.pylintEnabled": true,
1413
"python.linting.pylintPath": "/usr/local/bin/pylint"
1514
},
16-
1715
// Add the IDs of extensions you want installed when the container is created.
1816
"extensions": [
19-
"ms-python.python"
17+
"ms-python.python",
18+
"ms-vscode.cpptools"
2019
]
21-
2220
// Use 'forwardPorts' to make a list of ports inside the container available locally.
2321
// "forwardPorts": [],
24-
2522
// Use 'postCreateCommand' to run commands after the container is created.
2623
// "postCreateCommand": "pip install -r requirements.txt",
27-
2824
// Uncomment to connect as a non-root user. See https://aka.ms/vscode-remote/containers/non-root.
2925
// "remoteUser": "vscode"
30-
}
26+
}

test/test_interpolate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import unittest
2+
import torch
3+
from torch_points import three_interpolate_tg, three_interpolate, three_nn
4+
5+
class TestInterpolate(unittest.TestCase):
6+
def test_cpu(self):
7+
pos = torch.randn([16, 100, 3])
8+
pos_skip = torch.randn([16, 500, 3])
9+
x = torch.randn([16, 30, 100])
10+
11+
# # dense
12+
# dist, idx = three_nn(pos_skip, pos)
13+
# dist_recip = 1.0 / (dist + 1e-8)
14+
# norm = torch.sum(dist_recip, dim=2, keepdim=True)
15+
# weight = dist_recip / norm
16+
# interpolated_feats = three_interpolate(x, idx, weight)
17+
18+
# sparse
19+
sp_interpolated = three_interpolate_tg(x,pos,pos_skip)
20+
21+
# torch.testing.assert_allclose(sp_interpolated, interpolated_feats)
22+
23+
if __name__ == "__main__":
24+
unittest.main()

torch_points/knn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,6 @@ def knn(pos_support, pos, k):
1111
idx - [B,M,k]
1212
dist2 - [B,M,k] squared distances
1313
"""
14+
assert pos_support.dim() == 3 and pos.dim() == 3
15+
1416
return tpcpu.dense_knn(pos_support, pos, k)

torch_points/knn_interpolate.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch
2+
from .knn import knn
3+
4+
def knn_interpolate(x, pos, pos_support, k):
5+
""" KNN interpolation for dense data
6+
7+
Parameters
8+
----------
9+
x : (B, C, n) tensor of known features
10+
pos : (B, n, 3) tensor of positions of known features
11+
pos_support : (B, m, 3) tensor of position of unknown features (generally m > n)
12+
13+
Returns
14+
-------
15+
(B, C, m) interpolated features
16+
"""
17+
18+
knn_idx, knn_dist = knn(pos_support, pos, k)
19+
20+

torch_points/torchpoints.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.nn as nn
44
import sys
55
from typing import Optional, Any, Tuple
6+
from torch_geometric.nn import knn_interpolate
67

78
import torch_points.points_cpu as tpcpu
89

@@ -141,6 +142,13 @@ def three_interpolate(features, idx, weight):
141142
"""
142143
return ThreeInterpolate.apply(features, idx, weight)
143144

145+
def three_interpolate_tg(x, pos, new_pos):
146+
interpolated = []
147+
for i in range(x.shape[0]):
148+
interpolated.append(knn_interpolate(x.transpose(1,0), pos, new_pos).transpose(1,0))
149+
return torch.stack(interpolated)
150+
151+
144152
def grouping_operation(features, idx):
145153
r"""
146154
Parameters

0 commit comments

Comments
 (0)