Skip to content

Commit 3344791

Browse files
authored
Triton neighbor list implementation (#373)
* cleaned up installation * keep just brute tests * add triton dep * first triton and pytorch implementations * fix assertion error * fixed last issue * less computations * reorganized code. added first cell implementation * upd * fixed all tests except one * added working cell implementation * working with larger block_atoms * more efficient cell * update the benchmark suite * shared triton implementation * update issue in benchmark * cleanup * cleanup * fix benchmark * cell implementation closer to CUDA * use a while loop instead of breaking which doesn't work in triton * better printing * initial sorted cell list impl * fix benchmark printing * nearly working cell impl * wip * wip * different cell impl * one more cell implementation * cuda graph comp * tiled version * another impl * faster version * memory coalesced cell neighbor impl * cleanup and keep just the last cell version * removing shared memory implementation * cleanup and file headers * remove CUDA implementations * missing function * fix for torch script * updating installation isntructions * making triton optional * install different triton package on windows and none on OSX * simplify CI deployment and testing * try without lock file * fix for flake? * fix python version * don't use cuda on ARM machines * don't try except in compilable code * no triton on aarch64 * cannot use delayed imports with torchscript * add ase as a dep * unfreeze torch version * fix the OSX issue with MPS not supporting float64 * added test for scripting, then compiling * fix cuda graphing of torchscripted models. update tests * restore script+compile test * get rid of setup_for_compile_cudagraphs * fix test warnings * undo some changes to benchmarks * rename caffeine * calculators should warmup before recompiling * catch in output_modules also the case where we are compiling * int32 dtype for neighbor list and num_pairs * added test for ASE calculator * no need to trigger compilation anymore * no need to trigger compilation * skip cuda test if no cuda available * skip on windows due to missing compiler * prevent triton recompilation with changing number of atoms and cutoffs * use triton_wrap for compatibility with more pytorch features * make scatter compilable, make box a registered buffer of OptimizedDistance * fix backwards compatibility * remove constraint inserted for exporting * undo * revert change to scatter * cleanup * optimized the pytorch brute neighborlist implementation to not do O(n^2) but O(n^2/2) computations and mem usage * simplify * changing the neighbor arrays from torch.int32 to torch.long had a significant performance boost
1 parent 75b16e6 commit 3344791

36 files changed

+1498
-1993
lines changed

.github/workflows/publish.yml

Lines changed: 12 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -6,166 +6,51 @@ on:
66

77
jobs:
88
build:
9-
name: Build wheels on ${{ matrix.os }}-${{ matrix.accelerator }}
10-
runs-on: ${{ matrix.os }}
11-
strategy:
12-
fail-fast: false
13-
matrix:
14-
os: [ubuntu-latest, ubuntu-24.04-arm, windows-2022, macos-latest]
15-
accelerator: [cpu, cu118, cu126] #, cu128]
16-
exclude:
17-
- os: ubuntu-24.04-arm
18-
accelerator: cu118
19-
- os: ubuntu-24.04-arm
20-
accelerator: cu126
21-
# - os: ubuntu-24.04-arm
22-
# accelerator: cu128
23-
- os: macos-latest
24-
accelerator: cu118
25-
- os: macos-latest
26-
accelerator: cu126
27-
# - os: macos-latest
28-
# accelerator: cu128
9+
name: Create source distribution
10+
runs-on: ubuntu-latest
2911

3012
steps:
31-
- name: Free space of Github Runner (otherwise it will fail by running out of disk)
32-
if: matrix.os == 'ubuntu-latest'
33-
run: |
34-
sudo rm -rf /usr/share/dotnet
35-
sudo rm -rf /opt/ghc
36-
sudo rm -rf "/usr/local/share/boost"
37-
sudo rm -rf "/usr/local/.ghcup"
38-
sudo rm -rf "/usr/local/julia1.9.2"
39-
sudo rm -rf "/usr/local/lib/android"
40-
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
41-
42-
- uses: actions/checkout@v4
13+
- uses: actions/checkout@v5
4314

44-
- uses: actions/setup-python@v5
15+
- uses: actions/setup-python@v6
4516
with:
4617
python-version: "3.13"
4718

4819
- name: Install cibuildwheel
49-
run: python -m pip install cibuildwheel==3.1.3
50-
51-
- name: Activate MSVC
52-
uses: ilammy/msvc-dev-cmd@v1
53-
with:
54-
toolset: 14.29
55-
if: matrix.os == 'windows-2022'
20+
run: pip install build
5621

57-
- name: Build wheels
58-
if: matrix.os != 'windows-2022'
59-
shell: bash
60-
run: python -m cibuildwheel --output-dir wheelhouse
61-
env:
62-
ACCELERATOR: ${{ matrix.accelerator }}
63-
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
64-
65-
- name: Build wheels
66-
if: matrix.os == 'windows-2022'
67-
shell: cmd # Use cmd on Windows to avoid bash environment taking priority over MSVC variables
68-
run: python -m cibuildwheel --output-dir wheelhouse
69-
env:
70-
ACCELERATOR: ${{ matrix.accelerator }}
71-
DISTUTILS_USE_SDK: "1" # Windows requires this to use vc for building
72-
SKIP_TORCH_COMPILE: "true"
22+
- name: Build pypi package
23+
run: python -m build --sdist
7324

7425
- uses: actions/upload-artifact@v4
7526
with:
76-
name: ${{ matrix.accelerator }}-cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
77-
path: ./wheelhouse/*.whl
27+
name: source_dist
28+
path: dist/*.tar.gz
7829

79-
publish-to-public-pypi:
30+
publish-to-pypi:
8031
name: >-
8132
Publish Python 🐍 distribution 📦 to PyPI
8233
needs:
8334
- build
8435
runs-on: ubuntu-latest
8536
environment:
8637
name: pypi
38+
url: https://pypi.org/p/torchmd-net
8739
permissions:
8840
id-token: write # IMPORTANT: mandatory for trusted publishing
89-
strategy:
90-
fail-fast: false
91-
matrix:
92-
accelerator: [cpu, cu118, cu126] #, hip, cu124, cu126, cu128]
9341

9442
steps:
9543
- name: Download all the dists
9644
uses: actions/download-artifact@v4
9745
with:
98-
pattern: "${{ matrix.accelerator }}-cibw-wheels*"
9946
path: dist/
10047
merge-multiple: true
10148

10249
- name: Publish distribution 📦 to PyPI
10350
uses: pypa/gh-action-pypi-publish@release/v1
10451
with:
10552
password: ${{ secrets.TMDNET_PYPI_API_TOKEN }}
106-
skip-existing: true
107-
108-
# publish-to-accelera-pypi:
109-
# name: >-
110-
# Publish Python 🐍 distribution 📦 to Acellera PyPI
111-
# needs:
112-
# - build
113-
# runs-on: ubuntu-latest
114-
# permissions: # Needed for GCP authentication
115-
# contents: "read"
116-
# id-token: "write"
117-
# strategy:
118-
# fail-fast: false
119-
# matrix:
120-
# accelerator: [cpu, cu118, cu126, cu128]
121-
122-
# steps:
123-
# - uses: actions/checkout@v4 # Needed for GCP authentication for some reason
124-
125-
# - name: Set up Cloud SDK
126-
# uses: google-github-actions/auth@v2
127-
# with:
128-
# workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }}
129-
# service_account: ${{ secrets.GCP_PYPI_SERVICE_ACCOUNT }}
130-
131-
# - name: Download all the dists
132-
# uses: actions/download-artifact@v4
133-
# with:
134-
# pattern: "${{ matrix.accelerator }}-cibw-wheels*"
135-
# path: dist/
136-
# merge-multiple: true
137-
138-
# - name: Publish distribution 📦 to Acellera PyPI
139-
# run: |
140-
# pip install build twine keyring keyrings.google-artifactregistry-auth
141-
# pip install -U packaging
142-
# twine upload --repository-url https://us-central1-python.pkg.dev/pypi-packages-455608/${{ matrix.accelerator }} dist/* --verbose --skip-existing
143-
144-
# publish-to-official-pypi:
145-
# name: >-
146-
# Publish Python 🐍 distribution 📦 to PyPI
147-
# needs:
148-
# - build
149-
# runs-on: ubuntu-latest
150-
# environment:
151-
# name: pypi
152-
# url: https://pypi.org/p/torchmd-net
153-
# permissions:
154-
# id-token: write # IMPORTANT: mandatory for trusted publishing
155-
156-
# steps:
157-
# - name: Download all the dists
158-
# uses: actions/download-artifact@v4
159-
# with:
160-
# pattern: "cu118-cibw-wheels*"
161-
# path: dist/
162-
# merge-multiple: true
163-
164-
# - name: Publish distribution 📦 to PyPI
165-
# uses: pypa/gh-action-pypi-publish@release/v1
166-
# with:
167-
# password: ${{ secrets.TMDNET_PYPI_API_TOKEN }}
168-
# skip_existing: true
53+
skip_existing: true
16954

17055
github-release:
17156
name: >-

.github/workflows/test.yml

Lines changed: 10 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,93 +17,28 @@ jobs:
1717
["ubuntu-latest", "ubuntu-22.04-arm", "macos-latest", "windows-2022"]
1818
python-version: ["3.13"]
1919

20-
defaults: # Needed for conda
21-
run:
22-
shell: bash -l {0}
23-
2420
steps:
2521
- name: Check out
26-
uses: actions/checkout@v4
27-
28-
- uses: conda-incubator/setup-miniconda@v3
29-
with:
30-
python-version: ${{ matrix.python-version }}
31-
channels: conda-forge
32-
conda-remove-defaults: "true"
33-
if: matrix.os != 'macos-13'
22+
uses: actions/checkout@v5
3423

35-
- uses: conda-incubator/setup-miniconda@v3
24+
- name: Install uv
25+
uses: astral-sh/setup-uv@v7
3626
with:
3727
python-version: ${{ matrix.python-version }}
38-
channels: conda-forge
39-
mamba-version: "*"
40-
conda-remove-defaults: "true"
41-
if: matrix.os == 'macos-13'
42-
43-
- name: Install OS-specific compilers
44-
run: |
45-
if [[ "${{ matrix.os }}" == "ubuntu-22.04-arm" ]]; then
46-
conda install gxx --channel conda-forge --override-channels
47-
elif [[ "${{ runner.os }}" == "Linux" ]]; then
48-
conda install gxx --channel conda-forge --override-channels
49-
elif [[ "${{ runner.os }}" == "macOS" ]]; then
50-
conda install llvm-openmp pybind11 --channel conda-forge --override-channels
51-
echo "CC=clang" >> $GITHUB_ENV
52-
echo "CXX=clang++" >> $GITHUB_ENV
53-
elif [[ "${{ runner.os }}" == "Windows" ]]; then
54-
conda install vc vc14_runtime vs2015_runtime --channel conda-forge --override-channels
55-
fi
56-
57-
- name: List the conda environment
58-
run: conda list
5928

60-
- name: Install testing packages
61-
run: conda install -y -c conda-forge flake8 pytest psutil python-build
29+
- name: Install the project
30+
run: uv sync --all-extras --dev
6231

6332
- name: Lint with flake8
6433
run: |
6534
# stop the build if there are Python syntax errors or undefined names
66-
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35+
uv run flake8 ./torchmdnet --count --select=E9,F63,F7,F82 --show-source --statistics
6736
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
68-
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
69-
70-
- name: Set pytorch index
71-
run: |
72-
if [[ "${{ runner.os }}" == "Windows" ]]; then
73-
mkdir -p "C:\ProgramData\pip"
74-
echo "[global]
75-
extra-index-url = https://download.pytorch.org/whl/cpu" > "C:\ProgramData\pip\pip.ini"
76-
else
77-
mkdir -p $HOME/.config/pip
78-
echo "[global]
79-
extra-index-url = https://download.pytorch.org/whl/cpu" > $HOME/.config/pip/pip.conf
80-
fi
81-
82-
- name: Build and install the package
83-
run: |
84-
if [[ "${{ runner.os }}" == "Windows" ]]; then
85-
export LIB="C:/Miniconda/envs/test/Library/lib"
86-
fi
87-
python -m build
88-
pip install dist/*.whl
89-
env:
90-
ACCELERATOR: "cpu"
91-
92-
# - name: Install nnpops
93-
# if: matrix.os == 'ubuntu-latest' || matrix.os == 'macos-latest'
94-
# run: conda install nnpops --channel conda-forge --override-channels
95-
96-
- name: List the conda environment
97-
run: conda list
37+
uv run flake8 ./torchmdnet --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
9838
9939
- name: Run tests
100-
run: pytest -v -s --durations=10
101-
env:
102-
ACCELERATOR: "cpu"
103-
SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
104-
OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
105-
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
106-
LONG_TRAIN: "true"
40+
# For example, using `pytest`
41+
run: uv run pytest tests
10742

10843
- name: Test torchmd-train utility
109-
run: torchmd-train --help
44+
run: uv run torchmd-train --help

README.md

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,17 @@ Documentation is available at https://torchmd-net.readthedocs.io
2121

2222

2323
## Installation
24-
TorchMD-Net is available as a pip installable wheel as well as in [conda-forge](https://conda-forge.org/)
24+
TorchMD-Net is available as a pip package as well as in [conda-forge](https://conda-forge.org/)
2525

26-
TorchMD-Net provides builds for CPU-only, CUDA 11 and CUDA 12. CPU versions are only provided as reference,
27-
as the performance will be extremely limited.
28-
Depending on which variant you wish to install, you can install it with one of the following commands:
26+
As TorchMD-Net depends on PyTorch we need to add additional index URLs to the installation command as per [pytorch](https://pytorch.org/get-started/locally/)
2927

3028
```sh
31-
# The following will install the CUDA 11.8 version
32-
pip install torchmd-net-cu11 --extra-index-url https://download.pytorch.org/whl/cu118
33-
# The following will install the CUDA 12.4 version
34-
pip install torchmd-net-cu12 --extra-index-url https://download.pytorch.org/whl/cu124
35-
# The following will install the CPU only version (not recommended)
36-
pip install torchmd-net-cpu --extra-index-url https://download.pytorch.org/whl/cpu
29+
# The following will install TorchMD-Net with PyTorch CUDA 11.8 version
30+
pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu118
31+
# The following will install TorchMD-Net with PyTorch CUDA 12.4 version
32+
pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu124
33+
# The following will install TorchMD-Net with PyTorch CPU only version (not recommended)
34+
pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cpu
3735
```
3836

3937
Alternatively it can be installed with conda or mamba with one of the following commands.
@@ -46,7 +44,7 @@ mamba install torchmd-net cuda-version=12.4
4644

4745
### Install from source
4846

49-
TorchMD-Net is installed using pip, but you will need to install some dependencies before. Check [this documentation page](https://torchmd-net.readthedocs.io/en/latest/installation.html#install-from-source).
47+
TorchMD-Net is installed using pip with `pip install -e .` to create an editable install.
5048

5149
## Usage
5250
Specifying training arguments can either be done via a configuration yaml file or through command line arguments directly. Several examples of architectural and training specifications for some models and datasets can be found in [examples/](https://github.com/torchmd/torchmd-net/tree/main/examples). Note that if a parameter is present both in the yaml file and the command line, the command line version takes precedence.

0 commit comments

Comments
 (0)