Commit 3344791
authored
Triton neighbor list implementation (#373)
* cleaned up installation
* keep just brute tests
* add triton dep
* first triton and pytorch implementations
* fix assertion error
* fixed last issue
* less computations
* reorganized code. added first cell implementation
* upd
* fixed all tests except one
* added working cell implementation
* working with larger block_atoms
* more efficient cell
* update the benchmark suite
* shared triton implementation
* update issue in benchmark
* cleanup
* cleanup
* fix benchmark
* cell implementation closer to CUDA
* use a while loop instead of breaking which doesn't work in triton
* better printing
* initial sorted cell list impl
* fix benchmark printing
* nearly working cell impl
* wip
* wip
* different cell impl
* one more cell implementation
* cuda graph comp
* tiled version
* another impl
* faster version
* memory coalesced cell neighbor impl
* cleanup and keep just the last cell version
* removing shared memory implementation
* cleanup and file headers
* remove CUDA implementations
* missing function
* fix for torch script
* updating installation isntructions
* making triton optional
* install different triton package on windows and none on OSX
* simplify CI deployment and testing
* try without lock file
* fix for flake?
* fix python version
* don't use cuda on ARM machines
* don't try except in compilable code
* no triton on aarch64
* cannot use delayed imports with torchscript
* add ase as a dep
* unfreeze torch version
* fix the OSX issue with MPS not supporting float64
* added test for scripting, then compiling
* fix cuda graphing of torchscripted models. update tests
* restore script+compile test
* get rid of setup_for_compile_cudagraphs
* fix test warnings
* undo some changes to benchmarks
* rename caffeine
* calculators should warmup before recompiling
* catch in output_modules also the case where we are compiling
* int32 dtype for neighbor list and num_pairs
* added test for ASE calculator
* no need to trigger compilation anymore
* no need to trigger compilation
* skip cuda test if no cuda available
* skip on windows due to missing compiler
* prevent triton recompilation with changing number of atoms and cutoffs
* use triton_wrap for compatibility with more pytorch features
* make scatter compilable, make box a registered buffer of OptimizedDistance
* fix backwards compatibility
* remove constraint inserted for exporting
* undo
* revert change to scatter
* cleanup
* optimized the pytorch brute neighborlist implementation to not do O(n^2) but O(n^2/2) computations and mem usage
* simplify
* changing the neighbor arrays from torch.int32 to torch.long had a significant performance boost1 parent 75b16e6 commit 3344791
File tree
36 files changed
+1498
-1993
lines changed- .github/workflows
- benchmarks
- cibuildwheel_support
- docs/source
- examples/aceff_examples
- tests
- torchmdnet
- extensions
- neighbors
- models
36 files changed
+1498
-1993
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
9 | | - | |
10 | | - | |
11 | | - | |
12 | | - | |
13 | | - | |
14 | | - | |
15 | | - | |
16 | | - | |
17 | | - | |
18 | | - | |
19 | | - | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
| 9 | + | |
| 10 | + | |
29 | 11 | | |
30 | 12 | | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
| 13 | + | |
43 | 14 | | |
44 | | - | |
| 15 | + | |
45 | 16 | | |
46 | 17 | | |
47 | 18 | | |
48 | 19 | | |
49 | | - | |
50 | | - | |
51 | | - | |
52 | | - | |
53 | | - | |
54 | | - | |
55 | | - | |
| 20 | + | |
56 | 21 | | |
57 | | - | |
58 | | - | |
59 | | - | |
60 | | - | |
61 | | - | |
62 | | - | |
63 | | - | |
64 | | - | |
65 | | - | |
66 | | - | |
67 | | - | |
68 | | - | |
69 | | - | |
70 | | - | |
71 | | - | |
72 | | - | |
| 22 | + | |
| 23 | + | |
73 | 24 | | |
74 | 25 | | |
75 | 26 | | |
76 | | - | |
77 | | - | |
| 27 | + | |
| 28 | + | |
78 | 29 | | |
79 | | - | |
| 30 | + | |
80 | 31 | | |
81 | 32 | | |
82 | 33 | | |
83 | 34 | | |
84 | 35 | | |
85 | 36 | | |
86 | 37 | | |
| 38 | + | |
87 | 39 | | |
88 | 40 | | |
89 | | - | |
90 | | - | |
91 | | - | |
92 | | - | |
93 | 41 | | |
94 | 42 | | |
95 | 43 | | |
96 | 44 | | |
97 | 45 | | |
98 | | - | |
99 | 46 | | |
100 | 47 | | |
101 | 48 | | |
102 | 49 | | |
103 | 50 | | |
104 | 51 | | |
105 | 52 | | |
106 | | - | |
107 | | - | |
108 | | - | |
109 | | - | |
110 | | - | |
111 | | - | |
112 | | - | |
113 | | - | |
114 | | - | |
115 | | - | |
116 | | - | |
117 | | - | |
118 | | - | |
119 | | - | |
120 | | - | |
121 | | - | |
122 | | - | |
123 | | - | |
124 | | - | |
125 | | - | |
126 | | - | |
127 | | - | |
128 | | - | |
129 | | - | |
130 | | - | |
131 | | - | |
132 | | - | |
133 | | - | |
134 | | - | |
135 | | - | |
136 | | - | |
137 | | - | |
138 | | - | |
139 | | - | |
140 | | - | |
141 | | - | |
142 | | - | |
143 | | - | |
144 | | - | |
145 | | - | |
146 | | - | |
147 | | - | |
148 | | - | |
149 | | - | |
150 | | - | |
151 | | - | |
152 | | - | |
153 | | - | |
154 | | - | |
155 | | - | |
156 | | - | |
157 | | - | |
158 | | - | |
159 | | - | |
160 | | - | |
161 | | - | |
162 | | - | |
163 | | - | |
164 | | - | |
165 | | - | |
166 | | - | |
167 | | - | |
168 | | - | |
| 53 | + | |
169 | 54 | | |
170 | 55 | | |
171 | 56 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | 20 | | |
25 | 21 | | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
| 22 | + | |
34 | 23 | | |
35 | | - | |
| 24 | + | |
| 25 | + | |
36 | 26 | | |
37 | 27 | | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | | - | |
44 | | - | |
45 | | - | |
46 | | - | |
47 | | - | |
48 | | - | |
49 | | - | |
50 | | - | |
51 | | - | |
52 | | - | |
53 | | - | |
54 | | - | |
55 | | - | |
56 | | - | |
57 | | - | |
58 | | - | |
59 | 28 | | |
60 | | - | |
61 | | - | |
| 29 | + | |
| 30 | + | |
62 | 31 | | |
63 | 32 | | |
64 | 33 | | |
65 | 34 | | |
66 | | - | |
| 35 | + | |
67 | 36 | | |
68 | | - | |
69 | | - | |
70 | | - | |
71 | | - | |
72 | | - | |
73 | | - | |
74 | | - | |
75 | | - | |
76 | | - | |
77 | | - | |
78 | | - | |
79 | | - | |
80 | | - | |
81 | | - | |
82 | | - | |
83 | | - | |
84 | | - | |
85 | | - | |
86 | | - | |
87 | | - | |
88 | | - | |
89 | | - | |
90 | | - | |
91 | | - | |
92 | | - | |
93 | | - | |
94 | | - | |
95 | | - | |
96 | | - | |
97 | | - | |
| 37 | + | |
98 | 38 | | |
99 | 39 | | |
100 | | - | |
101 | | - | |
102 | | - | |
103 | | - | |
104 | | - | |
105 | | - | |
106 | | - | |
| 40 | + | |
| 41 | + | |
107 | 42 | | |
108 | 43 | | |
109 | | - | |
| 44 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
24 | | - | |
| 24 | + | |
25 | 25 | | |
26 | | - | |
27 | | - | |
28 | | - | |
| 26 | + | |
29 | 27 | | |
30 | 28 | | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
37 | 35 | | |
38 | 36 | | |
39 | 37 | | |
| |||
46 | 44 | | |
47 | 45 | | |
48 | 46 | | |
49 | | - | |
| 47 | + | |
50 | 48 | | |
51 | 49 | | |
52 | 50 | | |
| |||
0 commit comments