Skip to content

Commit a837f2c

Browse files
committed
Update package manager to poetry.
1 parent 3aef8af commit a837f2c

File tree

12 files changed

+2322
-1192
lines changed

12 files changed

+2322
-1192
lines changed

.clang-format

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
BasedOnStyle: Google
2+
3+
ColumnLimit: 120
4+
IndentWidth: 4
5+
Standard: Cpp11
6+
AccessModifierOffset: -4
7+
IndentCaseLabels: false
8+
9+
AlignAfterOpenBracket: Align
10+
AlignConsecutiveAssignments: false
11+
AlignTrailingComments: true
12+
AlignOperands: true
13+
AllowAllParametersOfDeclarationOnNextLine: false
14+
AllowAllConstructorInitializersOnNextLine: false
15+
AllowShortBlocksOnASingleLine: false
16+
AllowShortFunctionsOnASingleLine: false
17+
BinPackArguments: true
18+
BreakConstructorInitializers: BeforeColon
19+
BreakConstructorInitializersBeforeComma: true
20+
Cpp11BracedListStyle: false
21+
ConstructorInitializerAllOnOneLineOrOnePerLine: false
22+
DerivePointerAlignment: false
23+
MaxEmptyLinesToKeep: 1
24+
PointerAlignment: Right
25+
ReflowComments: false
26+
SortIncludes: false
27+
28+
UseTab: Never

build_cxx.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import platform
2+
3+
from setuptools import setup
4+
from setuptools.errors import CCompilerError, PackageDiscoveryError
5+
from torch.utils.cpp_extension import BuildExtension
6+
7+
8+
class MyBuildExtension(BuildExtension):
9+
def run(self):
10+
try:
11+
super(MyBuildExtension, self).run()
12+
except FileNotFoundError:
13+
raise Exception("File not found. Could not compile C extension")
14+
15+
def build_extension(self, ext):
16+
# common settings
17+
for e in self.extensions:
18+
pass
19+
20+
# OS specific settings
21+
if platform.system() == "Darwin":
22+
for e in self.extensions:
23+
e.extra_compile_args.extend([
24+
"-Xpreprocessor",
25+
"-fopenmp",
26+
"-mmacosx-version-min=10.15",
27+
])
28+
e.extra_link_args.extend([
29+
"-lomp",
30+
])
31+
32+
elif platform.system() == "Linux":
33+
for e in self.extensions:
34+
e.extra_compile_args.extend([
35+
"-fopenmp",
36+
])
37+
e.extra_link_args.extend([
38+
"-fopenmp",
39+
])
40+
41+
# compiler specific settings
42+
if self.compiler.compiler_type == "unix":
43+
for e in self.extensions:
44+
e.extra_compile_args.extend([
45+
"-std=c++17",
46+
"-pthread",
47+
])
48+
49+
elif self.compiler.compiler_type == "msvc":
50+
for e in self.extensions:
51+
e.extra_compile_args.extend(["/utf-8", "/std:c++17", "/openmp"])
52+
e.define_macros.extend([
53+
("_CRT_SECURE_NO_WARNINGS", 1),
54+
("_SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING", 1),
55+
])
56+
57+
# building
58+
try:
59+
super(MyBuildExtension, self).build_extension(ext)
60+
except (CCompilerError, PackageDiscoveryError, ValueError):
61+
raise Exception("Could not compile C extension")
62+
63+
64+
def build(setup_kwargs):
65+
from torch.utils.cpp_extension import CUDAExtension
66+
67+
try:
68+
setup_kwargs.update({
69+
"ext_modules": [
70+
CUDAExtension(
71+
'torchmcubes_module',
72+
[
73+
'cxx/pscan.cu',
74+
'cxx/mcubes.cpp',
75+
'cxx/mcubes_cpu.cpp',
76+
'cxx/mcubes_cuda.cu',
77+
'cxx/grid_interp_cpu.cpp',
78+
'cxx/grid_interp_cuda.cu',
79+
],
80+
extra_compile_args=['-DWITH_CUDA'],
81+
)
82+
],
83+
"cmdclass": {
84+
"build_ext": BuildExtension
85+
}
86+
})
87+
except:
88+
from torch.utils.cpp_extension import CppExtension
89+
90+
print('CUDA environment was not successfully loaded!')
91+
print('Build only CPU module!')
92+
93+
setup_kwargs.update({
94+
'ext_modules': [
95+
CppExtension('torchmcubes_module', [
96+
'cxx/mcubes.cpp',
97+
'cxx/mcubes_cpu.cpp',
98+
'cxx/grid_interp_cpu.cpp',
99+
])
100+
],
101+
'cmdclass': {
102+
'build_ext': MyBuildExtension
103+
}
104+
})

cxx/grid_interp_cpu.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,15 @@ torch::Tensor grid_interp_cpu(torch::Tensor vol, torch::Tensor points) {
2525
const int C = vol.size(0);
2626
const int Np = points.size(0);
2727

28-
torch::Tensor output = torch::zeros({Np, C},
29-
torch::TensorOptions().dtype(torch::kFloat32).device(vol.device()));
28+
torch::Tensor output = torch::zeros({ Np, C }, torch::TensorOptions().dtype(torch::kFloat32).device(vol.device()));
3029

3130
auto vol_ascr = vol.accessor<float, 4>();
3231
auto pts_ascr = points.accessor<float, 2>();
3332
auto out_ascr = output.accessor<float, 2>();
3433

35-
#ifdef _OPENMP
36-
#pragma omp parallel for
37-
#endif
34+
#ifdef _OPENMP
35+
#pragma omp parallel for
36+
#endif
3837
for (int i = 0; i < Np; i++) {
3938
const float x = pts_ascr[i][0];
4039
const float y = pts_ascr[i][1];
@@ -59,12 +58,12 @@ torch::Tensor grid_interp_cpu(torch::Tensor vol, torch::Tensor points) {
5958
const float v01 = (1.0 - fx) * vol_ascr[c][z0][y1][x0] + fx * vol_ascr[c][z0][y1][x1];
6059
const float v10 = (1.0 - fx) * vol_ascr[c][z1][y0][x0] + fx * vol_ascr[c][z1][y0][x1];
6160
const float v11 = (1.0 - fx) * vol_ascr[c][z1][y1][x0] + fx * vol_ascr[c][z1][y1][x1];
62-
61+
6362
const float v0 = (1.0 - fy) * v00 + fy * v01;
6463
const float v1 = (1.0 - fy) * v10 + fy * v11;
6564

6665
out_ascr[i][c] = (1.0 - fz) * v0 + fz * v1;
67-
}
66+
}
6867
}
6968

7069
return output;

0 commit comments

Comments
 (0)