This repository contains the implementation and experiments for exploring E(3)-equivariance in Graph Neural Networks (GNNs). It demonstrates how incorporating geometric symmetry in GNNs can improve their ability to model molecular properties using the QM9 dataset.
This project illustrates the findings of Hoogeboom et al., demonstrating the advantages of E(3)-equivariance in molecular property prediction tasks. Specifically, we:
- Use the QM9 dataset for a regression task predicting the molecular dipole moment.
- Compare standard GNNs with E(3)-equivariant GNNs.
- Demonstrate the benefits of leveraging geometric structure in molecular data.
The QM9 dataset contains small molecules with up to 29 atoms and features:
- Node features: Atom types, atomic number, aromaticity, hybridization, and more.
- Edge features: Bond types (single, double, triple, aromatic).
- Target properties: 19 physical properties, with our focus on the dipole moment.
For simplicity:
- We filter molecules with 12 atoms or fewer, resulting in 4005 molecules.
- The dataset is split into 80% train, 10% validation, and 10% test.
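The filtering and splitting step can be sketched independently of the dataset loader. The snippet below is an illustration, not the repository's preprocessing code: `Mol` is a stand-in for a PyTorch Geometric `Data` object, and the toy data merely mimics QM9's range of molecule sizes.

```python
import random
from dataclasses import dataclass

@dataclass
class Mol:
    num_nodes: int  # number of atoms in the molecule

def filter_and_split(dataset, max_atoms=12, seed=0):
    """Keep molecules with <= max_atoms atoms, then split 80/10/10."""
    kept = [m for m in dataset if m.num_nodes <= max_atoms]
    rng = random.Random(seed)
    rng.shuffle(kept)
    n = len(kept)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    return (kept[:n_train],
            kept[n_train:n_train + n_val],
            kept[n_train + n_val:])

# toy data: molecule sizes cycle through 1..29, as in QM9
data = [Mol(num_nodes=1 + i % 29) for i in range(1000)]
train, val, test = filter_and_split(data)
```

The same filter applied to the real QM9 data yields the 4005 molecules mentioned above.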
We evaluate the following models:
- LinReg: A simple baseline that ignores graph structure.
- MPNN: A standard Message Passing Neural Network.
- EGNN: An E(3)-equivariant GNN that incorporates spatial symmetry.
- EGNN_edge: A variant of EGNN that uses edge features.
All models are permutation invariant, and the EGNN-based models are E(3)-equivariant.
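E(3)-equivariance can be verified numerically. The sketch below implements a single EGNN-style message-passing layer with randomly initialised two-layer MLPs standing in for the learned networks, then checks that rotating and translating the input coordinates rotates and translates the output coordinates while leaving the node features unchanged. This is an illustration of the EGNN update rule, not the repository's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(params, x):
    """Two-layer MLP with tanh hidden activation."""
    W1, b1, W2, b2 = params
    return np.tanh(x @ W1 + b1) @ W2 + b2

def init(d_in, d_hid, d_out):
    return (rng.normal(size=(d_in, d_hid)), rng.normal(size=d_hid),
            rng.normal(size=(d_hid, d_out)), rng.normal(size=d_out))

D, M, H = 4, 4, 8              # feature dim, message dim, hidden dim
phi_e = init(2 * D + 1, H, M)  # edge MLP: (h_i, h_j, |x_i - x_j|^2) -> message
phi_x = init(M, H, 1)          # coordinate MLP: message -> scalar edge weight
phi_h = init(D + M, H, D)      # node MLP: (h_i, aggregated messages) -> new h_i

def egnn_layer(h, x):
    """One EGNN layer: invariant feature update, equivariant coordinate update."""
    n = len(h)
    h_out, x_out = np.empty_like(h), np.empty_like(x)
    for i in range(n):
        m_sum, shift = np.zeros(M), np.zeros(3)
        for j in range(n):
            if i == j:
                continue
            diff = x[i] - x[j]
            # messages depend on coordinates only through the invariant |diff|^2
            m = mlp(phi_e, np.concatenate([h[i], h[j], [diff @ diff]]))
            m_sum += m
            shift += diff * mlp(phi_x, m)[0]
        x_out[i] = x[i] + shift
        h_out[i] = mlp(phi_h, np.concatenate([h[i], m_sum]))
    return h_out, x_out

# check equivariance: rotate + translate the input coordinates
h = rng.normal(size=(5, D))
x = rng.normal(size=(5, 3))
R, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix
t = rng.normal(size=3)
h1, x1 = egnn_layer(h, x)
h2, x2 = egnn_layer(h, x @ R.T + t)
assert np.allclose(h2, h1, atol=1e-6)            # features are invariant
assert np.allclose(x2, x1 @ R.T + t, atol=1e-6)  # coordinates are equivariant
```

The coordinate update is equivariant because each shift is a linear combination of relative position vectors, which rotate with the input, while the MLP weights act only on rotation-invariant quantities.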
Our pipeline consists of:
- Preprocessing: Filtering and normalizing the QM9 dataset.
- Training:
- Models are trained for 500 epochs using the Adam optimizer and MSE loss.
- Learning rate scheduling with ReduceLROnPlateau.
- Validation: Evaluate models on unseen data to compare their generalization.
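The learning-rate schedule follows the usual plateau logic: when the validation loss stops improving for a number of epochs, the rate is multiplied by a decay factor. A minimal pure-Python sketch of that behaviour (mirroring PyTorch's `ReduceLROnPlateau`; the factor and patience values here are illustrative, not the ones used in training):

```python
class PlateauReducer:
    """Reduce the learning rate when a monitored loss stops improving."""
    def __init__(self, lr, factor=0.5, patience=10):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

sched = PlateauReducer(lr=1e-3, factor=0.5, patience=3)
losses = [1.0, 0.8, 0.7] + [0.7] * 5   # improvement, then a plateau
lrs = [sched.step(l) for l in losses]  # rate halves once the plateau outlasts patience
```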
Frameworks:
- PyTorch Geometric for graph data.
- WandB for experiment tracking.
- Clone the repository:

```bash
git clone https://github.com/your_username/EGGN_Classifier.git
cd EGGN_Classifier
```

- Set up the environment:

```bash
conda create -n egnn-env python=3.8
conda activate egnn-env
pip install -r requirements.txt
```
- Train models:

```bash
python main.py --model EGNN --epochs 500
```

Options for `--model`: `LinReg`, `MPNN`, `EGNN`, `EGNN_edge`.

- Evaluate models:

```bash
python eval.py --model EGNN
```
- Visualization: results such as training/validation loss and learning rate are logged in WandB.
- LinReg: performs poorly, while the graph-based models minimize the loss effectively.
- EGNN-based models: Outperform MPNNs, with EGNN_edge achieving the best generalization.
- E(3)-equivariance: Leads to significant performance gains.
- Incorporating edge features improves model expressivity.
Figures and detailed discussions are in the report.
Contributions are welcome! We would especially appreciate an SE(3)-equivariant model, which we expect to produce better results. Please:
- Fork the repository.
- Create a new branch (e.g. `feature/new-feature`).
- Submit a pull request.
- Hoogeboom et al. for their foundational work on E(3)-equivariant GNNs.
- PyTorch Geometric for their framework for graph data.
- The QM9 dataset creators.