-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathkaggle_dataset_graph_network
More file actions
1 lines (1 loc) · 47.1 KB
/
kaggle_dataset_graph_network
File metadata and controls
1 lines (1 loc) · 47.1 KB
1
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":59093,"databundleVersionId":7469972,"sourceType":"competition"}],"dockerImageVersionId":30648,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# GNNs meets HMS","metadata":{}},{"cell_type":"code","source":"import os\n\nimport numpy as np\nimport pandas as pd\n\nimport torch\nfrom torch.utils.data import Dataset\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\n\n!pip install torch_geometric\nimport torch.nn.functional as F\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import SGConv, ChebConv, GCNConv, GATConv\nfrom torch_geometric.nn.pool import global_mean_pool\n\n!pip install torcheeg\nfrom torcheeg.transforms import Compose, BandDifferentialEntropy, MeanStdNormalize\nfrom torcheeg.transforms.pyg import ToDynamicG\n\nfrom sklearn.model_selection import KFold\nfrom sklearn.preprocessing import StandardScaler\n\nimport matplotlib.pyplot as plt","metadata":{"execution":{"iopub.status.busy":"2024-05-07T10:57:42.175186Z","iopub.execute_input":"2024-05-07T10:57:42.175956Z","iopub.status.idle":"2024-05-07T10:58:29.262422Z","shell.execute_reply.started":"2024-05-07T10:57:42.175921Z","shell.execute_reply":"2024-05-07T10:58:29.261544Z"},"trusted":true},"execution_count":1,"outputs":[{"name":"stdout","text":"Collecting torch_geometric\n Downloading torch_geometric-2.5.3-py3-none-any.whl.metadata (64 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.2/64.2 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n\u001b[?25hRequirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (4.66.1)\nRequirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (1.24.4)\nRequirement already satisfied: scipy in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (1.11.4)\nRequirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (2023.12.2)\nRequirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (3.1.2)\nRequirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (3.9.1)\nRequirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (2.31.0)\nRequirement already satisfied: pyparsing in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (3.1.1)\nRequirement already satisfied: scikit-learn in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (1.2.2)\nRequirement already satisfied: psutil>=5.8.0 in /opt/conda/lib/python3.10/site-packages (from torch_geometric) (5.9.3)\nRequirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->torch_geometric) (23.2.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->torch_geometric) (6.0.4)\nRequirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->torch_geometric) (1.9.3)\nRequirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->torch_geometric) (1.4.1)\nRequirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->torch_geometric) (1.3.1)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->torch_geometric) 
(4.0.3)\nRequirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch_geometric) (2.1.3)\nRequirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->torch_geometric) (3.3.2)\nRequirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->torch_geometric) (3.6)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->torch_geometric) (1.26.18)\nRequirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->torch_geometric) (2023.11.17)\nRequirement already satisfied: joblib>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from scikit-learn->torch_geometric) (1.3.2)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn->torch_geometric) (3.2.0)\nDownloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m16.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n\u001b[?25hInstalling collected packages: torch_geometric\nSuccessfully installed torch_geometric-2.5.3\nCollecting torcheeg\n Downloading torcheeg-1.1.2.tar.gz (214 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m214.5/214.5 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25ldone\n\u001b[?25hRequirement already satisfied: tqdm>=4.64.0 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (4.66.1)\nRequirement already satisfied: numpy>=1.21.5 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (1.24.4)\nRequirement already satisfied: pandas>=1.3.5 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (2.1.4)\nCollecting xlrd>=2.0.1 (from torcheeg)\n Downloading xlrd-2.0.1-py2.py3-none-any.whl.metadata (3.4 kB)\nRequirement already satisfied: scipy>=1.7.3 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (1.11.4)\nRequirement already satisfied: scikit-learn>=1.0.2 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (1.2.2)\nCollecting lmdb>=1.3.0 (from torcheeg)\n Downloading lmdb-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)\nCollecting einops>=0.4.1 (from torcheeg)\n Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)\nRequirement already satisfied: mne>=1.0.3 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (1.6.1)\nCollecting xmltodict>=0.13.0 (from torcheeg)\n Downloading xmltodict-0.13.0-py2.py3-none-any.whl.metadata (7.7 kB)\nRequirement already satisfied: networkx>=2.6.3 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (3.2.1)\nRequirement already satisfied: PyWavelets>=1.3.0 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (1.5.0)\nCollecting spectrum>=0.8.1 (from torcheeg)\n Downloading spectrum-0.8.1.tar.gz (230 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m230.8/230.8 kB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25ldone\n\u001b[?25hRequirement already satisfied: torchmetrics>=0.10.0 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (1.3.0.post0)\nCollecting mne_connectivity>=0.4.0 (from torcheeg)\n Downloading mne_connectivity-0.6.0-py3-none-any.whl.metadata (10 kB)\nRequirement already satisfied: pytorch-lightning>=1.9.5 in /opt/conda/lib/python3.10/site-packages (from torcheeg) (2.1.3)\nRequirement already satisfied: matplotlib>=3.5.0 in /opt/conda/lib/python3.10/site-packages (from mne>=1.0.3->torcheeg) (3.7.4)\nRequirement already satisfied: pooch>=1.5 in /opt/conda/lib/python3.10/site-packages (from mne>=1.0.3->torcheeg) (1.8.0)\nRequirement already satisfied: decorator in /opt/conda/lib/python3.10/site-packages (from mne>=1.0.3->torcheeg) (5.1.1)\nRequirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from mne>=1.0.3->torcheeg) (21.3)\nRequirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from mne>=1.0.3->torcheeg) (3.1.2)\nRequirement already satisfied: lazy-loader>=0.3 in /opt/conda/lib/python3.10/site-packages (from mne>=1.0.3->torcheeg) (0.3)\nCollecting netCDF4>=1.6.5 (from mne_connectivity>=0.4.0->torcheeg)\n Downloading netCDF4-1.6.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)\nRequirement already satisfied: xarray>=2023.11.0 in /opt/conda/lib/python3.10/site-packages (from mne_connectivity>=0.4.0->torcheeg) (2024.1.0)\nRequirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.10/site-packages (from pandas>=1.3.5->torcheeg) (2.8.2)\nRequirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas>=1.3.5->torcheeg) (2023.3.post1)\nRequirement already satisfied: tzdata>=2022.1 in /opt/conda/lib/python3.10/site-packages (from pandas>=1.3.5->torcheeg) (2023.4)\nRequirement already satisfied: torch>=1.12.0 in /opt/conda/lib/python3.10/site-packages (from pytorch-lightning>=1.9.5->torcheeg) 
(2.1.2)\nRequirement already satisfied: PyYAML>=5.4 in /opt/conda/lib/python3.10/site-packages (from pytorch-lightning>=1.9.5->torcheeg) (6.0.1)\nRequirement already satisfied: fsspec>=2022.5.0 in /opt/conda/lib/python3.10/site-packages (from fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (2023.12.2)\nRequirement already satisfied: typing-extensions>=4.0.0 in /opt/conda/lib/python3.10/site-packages (from pytorch-lightning>=1.9.5->torcheeg) (4.9.0)\nRequirement already satisfied: lightning-utilities>=0.8.0 in /opt/conda/lib/python3.10/site-packages (from pytorch-lightning>=1.9.5->torcheeg) (0.10.1)\nRequirement already satisfied: joblib>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=1.0.2->torcheeg) (1.3.2)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=1.0.2->torcheeg) (3.2.0)\nCollecting easydev (from spectrum>=0.8.1->torcheeg)\n Downloading easydev-0.13.2-py3-none-any.whl.metadata (3.5 kB)\nRequirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (2.31.0)\nRequirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /opt/conda/lib/python3.10/site-packages (from fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (3.9.1)\nRequirement already satisfied: setuptools in /opt/conda/lib/python3.10/site-packages (from lightning-utilities>=0.8.0->pytorch-lightning>=1.9.5->torcheeg) (69.0.3)\nRequirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib>=3.5.0->mne>=1.0.3->torcheeg) (1.2.0)\nRequirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib>=3.5.0->mne>=1.0.3->torcheeg) (0.12.1)\nRequirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib>=3.5.0->mne>=1.0.3->torcheeg) (4.47.0)\nRequirement already satisfied: kiwisolver>=1.0.1 
in /opt/conda/lib/python3.10/site-packages (from matplotlib>=3.5.0->mne>=1.0.3->torcheeg) (1.4.5)\nRequirement already satisfied: pillow>=6.2.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib>=3.5.0->mne>=1.0.3->torcheeg) (9.5.0)\nRequirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib>=3.5.0->mne>=1.0.3->torcheeg) (3.1.1)\nCollecting cftime (from netCDF4>=1.6.5->mne_connectivity>=0.4.0->torcheeg)\n Downloading cftime-1.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.6 kB)\nRequirement already satisfied: certifi in /opt/conda/lib/python3.10/site-packages (from netCDF4>=1.6.5->mne_connectivity>=0.4.0->torcheeg) (2023.11.17)\nRequirement already satisfied: platformdirs>=2.5.0 in /opt/conda/lib/python3.10/site-packages (from pooch>=1.5->mne>=1.0.3->torcheeg) (4.1.0)\nRequirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas>=1.3.5->torcheeg) (1.16.0)\nRequirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning>=1.9.5->torcheeg) (3.13.1)\nRequirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning>=1.9.5->torcheeg) (1.12)\nCollecting packaging (from mne>=1.0.3->torcheeg)\n Downloading packaging-24.0-py3-none-any.whl.metadata (3.2 kB)\nRequirement already satisfied: colorama<0.5.0,>=0.4.6 in /opt/conda/lib/python3.10/site-packages (from easydev->spectrum>=0.8.1->torcheeg) (0.4.6)\nCollecting colorlog<7.0.0,>=6.8.2 (from easydev->spectrum>=0.8.1->torcheeg)\n Downloading colorlog-6.8.2-py3-none-any.whl.metadata (10 kB)\nRequirement already satisfied: line-profiler<5.0.0,>=4.1.2 in /opt/conda/lib/python3.10/site-packages (from easydev->spectrum>=0.8.1->torcheeg) (4.1.2)\nCollecting pexpect<5.0.0,>=4.9.0 (from easydev->spectrum>=0.8.1->torcheeg)\n Downloading pexpect-4.9.0-py2.py3-none-any.whl.metadata 
(2.5 kB)\nCollecting platformdirs>=2.5.0 (from pooch>=1.5->mne>=1.0.3->torcheeg)\n Downloading platformdirs-4.2.1-py3-none-any.whl.metadata (11 kB)\nRequirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->mne>=1.0.3->torcheeg) (2.1.3)\nRequirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (23.2.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (6.0.4)\nRequirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (1.9.3)\nRequirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (1.4.1)\nRequirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (1.3.1)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (4.0.3)\nRequirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.10/site-packages (from pexpect<5.0.0,>=4.9.0->easydev->spectrum>=0.8.1->torcheeg) (0.7.0)\nRequirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (3.3.2)\nRequirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (3.6)\nRequirement already 
satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->fsspec[http]>=2022.5.0->pytorch-lightning>=1.9.5->torcheeg) (1.26.18)\nRequirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.12.0->pytorch-lightning>=1.9.5->torcheeg) (1.3.0)\nDownloading einops-0.8.0-py3-none-any.whl (43 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading lmdb-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (299 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m299.2/299.2 kB\u001b[0m \u001b[31m17.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading mne_connectivity-0.6.0-py3-none-any.whl (107 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.2/107.2 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading xlrd-2.0.1-py2.py3-none-any.whl (96 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m96.5/96.5 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading xmltodict-0.13.0-py2.py3-none-any.whl (10.0 kB)\nDownloading netCDF4-1.6.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m70.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hDownloading packaging-24.0-py3-none-any.whl (53 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.5/53.5 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading easydev-0.13.2-py3-none-any.whl (56 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m 
\u001b[32m56.8/56.8 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading colorlog-6.8.2-py3-none-any.whl (11 kB)\nDownloading pexpect-4.9.0-py2.py3-none-any.whl (63 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.8/63.8 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading platformdirs-4.2.1-py3-none-any.whl (17 kB)\nDownloading cftime-1.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m45.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hBuilding wheels for collected packages: torcheeg, spectrum\n Building wheel for torcheeg (setup.py) ... \u001b[?25ldone\n\u001b[?25h Created wheel for torcheeg: filename=torcheeg-1.1.2-py3-none-any.whl size=414907 sha256=d50a166ce8747ba05710ac5f44fda0bc772dc734e018e3ed4164c15b5b7bdba7\n Stored in directory: /root/.cache/pip/wheels/4b/9d/45/dc74737527e063f030217eda1f25791f98f06cc0bc6f013e4c\n Building wheel for spectrum (setup.py) ... 
\u001b[?25ldone\n\u001b[?25h Created wheel for spectrum: filename=spectrum-0.8.1-cp310-cp310-linux_x86_64.whl size=227335 sha256=c7a2ee59e6398313c195db2c26892afcfc6426f9efc3867ea45a0231e2c0feb4\n Stored in directory: /root/.cache/pip/wheels/e7/5a/09/ffc6afdf8a5a6f58e9851292108df32bb11374e11b8705cabd\nSuccessfully built torcheeg spectrum\nInstalling collected packages: lmdb, xmltodict, xlrd, platformdirs, pexpect, packaging, einops, colorlog, cftime, netCDF4, easydev, spectrum, mne_connectivity, torcheeg\n Attempting uninstall: platformdirs\n Found existing installation: platformdirs 4.1.0\n Uninstalling platformdirs-4.1.0:\n Successfully uninstalled platformdirs-4.1.0\n Attempting uninstall: pexpect\n Found existing installation: pexpect 4.8.0\n Uninstalling pexpect-4.8.0:\n Successfully uninstalled pexpect-4.8.0\n Attempting uninstall: packaging\n Found existing installation: packaging 21.3\n Uninstalling packaging-21.3:\n Successfully uninstalled packaging-21.3\n Attempting uninstall: colorlog\n Found existing installation: colorlog 6.8.0\n Uninstalling colorlog-6.8.0:\n Successfully uninstalled colorlog-6.8.0\n\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\ncudf 23.8.0 requires cubinlinker, which is not installed.\ncudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.\ncudf 23.8.0 requires ptxcompiler, which is not installed.\ncuml 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.\ndask-cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.\ncudf 23.8.0 requires cuda-python<12.0a0,>=11.7.1, but you have cuda-python 12.3.0 which is incompatible.\ncudf 23.8.0 requires pandas<1.6.0dev0,>=1.3, but you have pandas 2.1.4 which is incompatible.\ncudf 23.8.0 requires protobuf<5,>=4.21, but you have protobuf 3.20.3 which is incompatible.\ncuml 23.8.0 requires dask==2023.7.1, but you have dask 2024.1.0 which is incompatible.\ncuml 23.8.0 requires distributed==2023.7.1, but you have distributed 2024.1.0 which is incompatible.\ndask-cuda 23.8.0 requires dask==2023.7.1, but you have dask 2024.1.0 which is incompatible.\ndask-cuda 23.8.0 requires distributed==2023.7.1, but you have distributed 2024.1.0 which is incompatible.\ndask-cuda 23.8.0 requires pandas<1.6.0dev0,>=1.3, but you have pandas 2.1.4 which is incompatible.\ndask-cudf 23.8.0 requires dask==2023.7.1, but you have dask 2024.1.0 which is incompatible.\ndask-cudf 23.8.0 requires distributed==2023.7.1, but you have distributed 2024.1.0 which is incompatible.\ndask-cudf 23.8.0 requires pandas<1.6.0dev0,>=1.3, but you have pandas 2.1.4 which is incompatible.\ngoogle-cloud-bigquery 2.34.4 requires packaging<22.0dev,>=14.3, but you have packaging 24.0 which is incompatible.\njupyterlab 4.0.11 requires jupyter-lsp>=2.0.0, but you have jupyter-lsp 1.5.1 which is incompatible.\njupyterlab-lsp 5.0.2 requires jupyter-lsp>=2.0.0, but you have jupyter-lsp 1.5.1 which is incompatible.\nlibpysal 4.9.2 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.\nmomepy 0.7.0 requires shapely>=2, but you have shapely 1.8.5.post1 which is incompatible.\nosmnx 
1.8.1 requires shapely>=2.0, but you have shapely 1.8.5.post1 which is incompatible.\nraft-dask 23.8.0 requires dask==2023.7.1, but you have dask 2024.1.0 which is incompatible.\nraft-dask 23.8.0 requires distributed==2023.7.1, but you have distributed 2024.1.0 which is incompatible.\nspopt 0.6.0 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.\ntensorflowjs 4.16.0 requires packaging~=23.1, but you have packaging 24.0 which is incompatible.\nvirtualenv 20.21.0 requires platformdirs<4,>=2.4, but you have platformdirs 4.2.1 which is incompatible.\u001b[0m\u001b[31m\n\u001b[0mSuccessfully installed cftime-1.6.3 colorlog-6.8.2 easydev-0.13.2 einops-0.8.0 lmdb-1.4.1 mne_connectivity-0.6.0 netCDF4-1.6.5 packaging-24.0 pexpect-4.9.0 platformdirs-4.2.1 spectrum-0.8.1 torcheeg-1.1.2 xlrd-2.0.1 xmltodict-0.13.0\n","output_type":"stream"}]},{"cell_type":"markdown","source":"## Define config","metadata":{}},{"cell_type":"code","source":"config = {\n 'batch_size': 2048,\n 'device': \"cuda\" if torch.cuda.is_available() else \"cpu\",\n 'n_fold': 3,\n 'seed': 2**20-1,\n 'epochs': 10,\n 'in_channels': 10_000, #4\n 'hidden_channels': 32,\n 'num_conv_layers': 3,\n 'num_classes': 6,\n 'lr': 1e-3,\n 'weight_decay': 5e-4\n}","metadata":{"execution":{"iopub.status.busy":"2024-05-07T10:59:42.602641Z","iopub.execute_input":"2024-05-07T10:59:42.603581Z","iopub.status.idle":"2024-05-07T10:59:42.608747Z","shell.execute_reply.started":"2024-05-07T10:59:42.603546Z","shell.execute_reply":"2024-05-07T10:59:42.607841Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"markdown","source":"## Get data","metadata":{}},{"cell_type":"code","source":"EEG_PATH_TRAIN = '/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/'\n\ndf = 
pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/train.csv')\ndf","metadata":{"execution":{"iopub.status.busy":"2024-05-07T10:59:46.112428Z","iopub.execute_input":"2024-05-07T10:59:46.112789Z","iopub.status.idle":"2024-05-07T10:59:46.390652Z","shell.execute_reply.started":"2024-05-07T10:59:46.112762Z","shell.execute_reply":"2024-05-07T10:59:46.389516Z"},"trusted":true},"execution_count":6,"outputs":[{"execution_count":6,"output_type":"execute_result","data":{"text/plain":" eeg_id eeg_sub_id eeg_label_offset_seconds spectrogram_id \\\n0 1628180742 0 0.0 353733 \n1 1628180742 1 6.0 353733 \n2 1628180742 2 8.0 353733 \n3 1628180742 3 18.0 353733 \n4 1628180742 4 24.0 353733 \n... ... ... ... ... \n106795 351917269 6 12.0 2147388374 \n106796 351917269 7 14.0 2147388374 \n106797 351917269 8 16.0 2147388374 \n106798 351917269 9 18.0 2147388374 \n106799 351917269 10 20.0 2147388374 \n\n spectrogram_sub_id spectrogram_label_offset_seconds label_id \\\n0 0 0.0 127492639 \n1 1 6.0 3887563113 \n2 2 8.0 1142670488 \n3 3 18.0 2718991173 \n4 4 24.0 3080632009 \n... ... ... ... \n106795 6 12.0 4195677307 \n106796 7 14.0 290896675 \n106797 8 16.0 461435451 \n106798 9 18.0 3786213131 \n106799 10 20.0 3642716176 \n\n patient_id expert_consensus seizure_vote lpd_vote gpd_vote \\\n0 42516 Seizure 3 0 0 \n1 42516 Seizure 3 0 0 \n2 42516 Seizure 3 0 0 \n3 42516 Seizure 3 0 0 \n4 42516 Seizure 3 0 0 \n... ... ... ... ... ... \n106795 10351 LRDA 0 0 0 \n106796 10351 LRDA 0 0 0 \n106797 10351 LRDA 0 0 0 \n106798 10351 LRDA 0 0 0 \n106799 10351 LRDA 0 0 0 \n\n lrda_vote grda_vote other_vote \n0 0 0 0 \n1 0 0 0 \n2 0 0 0 \n3 0 0 0 \n4 0 0 0 \n... ... ... ... 
\n106795 3 0 0 \n106796 3 0 0 \n106797 3 0 0 \n106798 3 0 0 \n106799 3 0 0 \n\n[106800 rows x 15 columns]","text/html":"<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>eeg_id</th>\n <th>eeg_sub_id</th>\n <th>eeg_label_offset_seconds</th>\n <th>spectrogram_id</th>\n <th>spectrogram_sub_id</th>\n <th>spectrogram_label_offset_seconds</th>\n <th>label_id</th>\n <th>patient_id</th>\n <th>expert_consensus</th>\n <th>seizure_vote</th>\n <th>lpd_vote</th>\n <th>gpd_vote</th>\n <th>lrda_vote</th>\n <th>grda_vote</th>\n <th>other_vote</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1628180742</td>\n <td>0</td>\n <td>0.0</td>\n <td>353733</td>\n <td>0</td>\n <td>0.0</td>\n <td>127492639</td>\n <td>42516</td>\n <td>Seizure</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1628180742</td>\n <td>1</td>\n <td>6.0</td>\n <td>353733</td>\n <td>1</td>\n <td>6.0</td>\n <td>3887563113</td>\n <td>42516</td>\n <td>Seizure</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1628180742</td>\n <td>2</td>\n <td>8.0</td>\n <td>353733</td>\n <td>2</td>\n <td>8.0</td>\n <td>1142670488</td>\n <td>42516</td>\n <td>Seizure</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>3</th>\n <td>1628180742</td>\n <td>3</td>\n <td>18.0</td>\n <td>353733</td>\n <td>3</td>\n <td>18.0</td>\n <td>2718991173</td>\n <td>42516</td>\n <td>Seizure</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1628180742</td>\n <td>4</td>\n <td>24.0</td>\n <td>353733</td>\n <td>4</td>\n <td>24.0</td>\n 
<td>3080632009</td>\n <td>42516</td>\n <td>Seizure</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>106795</th>\n <td>351917269</td>\n <td>6</td>\n <td>12.0</td>\n <td>2147388374</td>\n <td>6</td>\n <td>12.0</td>\n <td>4195677307</td>\n <td>10351</td>\n <td>LRDA</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>106796</th>\n <td>351917269</td>\n <td>7</td>\n <td>14.0</td>\n <td>2147388374</td>\n <td>7</td>\n <td>14.0</td>\n <td>290896675</td>\n <td>10351</td>\n <td>LRDA</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>106797</th>\n <td>351917269</td>\n <td>8</td>\n <td>16.0</td>\n <td>2147388374</td>\n <td>8</td>\n <td>16.0</td>\n <td>461435451</td>\n <td>10351</td>\n <td>LRDA</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>106798</th>\n <td>351917269</td>\n <td>9</td>\n <td>18.0</td>\n <td>2147388374</td>\n <td>9</td>\n <td>18.0</td>\n <td>3786213131</td>\n <td>10351</td>\n <td>LRDA</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n </tr>\n <tr>\n <th>106799</th>\n <td>351917269</td>\n <td>10</td>\n <td>20.0</td>\n <td>2147388374</td>\n <td>10</td>\n <td>20.0</td>\n <td>3642716176</td>\n <td>10351</td>\n <td>LRDA</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>3</td>\n <td>0</td>\n <td>0</td>\n </tr>\n </tbody>\n</table>\n<p>106800 rows × 15 columns</p>\n</div>"},"metadata":{}}]},{"cell_type":"code","source":"EEG_IDS = df.eeg_id.unique()\n\nTARGETS = df.columns[-6:]\nTARS = {'Seizure':0, 'LPD':1, 'GPD':2, 'LRDA':3, 'GRDA':4, 'Other':5}\nTARS_INV = {x:y for y,x in 
TARS.items()}\n\ntrain = df.groupby('eeg_id')[['patient_id']].agg('first')\ntmp = df.groupby('eeg_id')[TARGETS].agg('sum')\nfor t in TARGETS:\n    train[t] = tmp[t].values\n\ny_data = train[TARGETS].values\ny_data = y_data / y_data.sum(axis=1,keepdims=True)\ntrain[TARGETS] = y_data\n\ntmp = df.groupby('eeg_id')[['expert_consensus']].agg('first')\ntrain['target'] = tmp\n\ntrain = train.reset_index()\ntrain = train.loc[train.eeg_id.isin(EEG_IDS)]\nprint('Train Data with unique eeg_id shape:', train.shape )","metadata":{"execution":{"iopub.status.busy":"2024-05-07T10:59:49.653510Z","iopub.execute_input":"2024-05-07T10:59:49.653888Z","iopub.status.idle":"2024-05-07T10:59:49.721877Z","shell.execute_reply.started":"2024-05-07T10:59:49.653859Z","shell.execute_reply":"2024-05-07T10:59:49.720960Z"},"trusted":true},"execution_count":7,"outputs":[{"name":"stdout","text":"Train Data with unique eeg_id shape: (17089, 9)\n","output_type":"stream"}]},
{"cell_type":"code","source":"def eeg_from_parquet(parquet_path, display=False, n_samples=10_000):\n    \"\"\"Load the middle `n_samples` rows of an EEG parquet file as a NumPy array.\n\n    Every column except 'EKG' is kept, in file order. Within a channel, NaNs\n    are replaced by the channel mean (via np.nan_to_num); a channel that is\n    entirely NaN becomes zeros. Recordings shorter than `n_samples` rows are\n    zero-padded at the end instead of failing.\n\n    Args:\n        parquet_path: path to one EEG parquet file.\n        display: if True, plot the stacked channels with matplotlib.\n        n_samples: number of rows to keep (default 10_000, i.e. 50 seconds).\n\n    Returns:\n        float64 array of shape (n_samples, number of non-EKG columns).\n    \"\"\"\n    eeg = pd.read_parquet(parquet_path)\n\n    # Extract the middle window (start at row 0 if the recording is too short).\n    start = max((len(eeg) - n_samples) // 2, 0)\n    eeg = eeg.iloc[start:start + n_samples]\n\n    # Select channels by name so the output column index never depends on\n    # where 'EKG' sits among the file's columns.\n    eeg_cols = [col for col in eeg.columns if col != 'EKG']\n    data = np.zeros((n_samples, len(eeg_cols)))\n\n    if display:\n        plt.figure(figsize=(10, 5))\n        plot_offset = 0\n\n    for j, col in enumerate(eeg_cols):\n        x = eeg[col].values.astype('float32')\n        if np.isnan(x).all():\n            x[:] = 0  # no valid samples; also avoids nanmean's empty-slice warning\n        else:\n            x = np.nan_to_num(x, nan=np.nanmean(x))\n        data[:len(x), j] = x\n\n        if display:\n            if j != 0:\n                plot_offset += x.max()\n            plt.plot(range(len(x)), x - plot_offset, label=col)\n            plot_offset -= x.min()\n\n    if display:\n        plt.legend()\n        name = parquet_path.split('/')[-1].split('.')[0]\n        plt.title(f'EEG {name}', size=16)\n        plt.show()\n\n    return data","metadata":{"execution":{"iopub.status.busy":"2024-05-07T10:59:53.328405Z","iopub.execute_input":"2024-05-07T10:59:53.328746Z","iopub.status.idle":"2024-05-07T10:59:53.338818Z","shell.execute_reply.started":"2024-05-07T10:59:53.328721Z","shell.execute_reply":"2024-05-07T10:59:53.337787Z"},"trusted":true},"execution_count":8,"outputs":[]},
{"cell_type":"markdown","source":"### Example: single-row analysis","metadata":{}},
{"cell_type":"code","source":"def get_one_row_data(row_id, mode='train'):\n    \"\"\"Return the EEG window (see eeg_from_parquet) for row `row_id` of `train`.\n\n    The row is always taken from the global `train` frame. With mode='train'\n    the file is read from EEG_PATH_TRAIN; any other mode reads from\n    EEG_PATH_TEST, which must be defined before such a call.\n    \"\"\"\n    row = train.iloc[row_id]\n    path = EEG_PATH_TRAIN if mode == 'train' else EEG_PATH_TEST\n    return eeg_from_parquet(f'{path}{row.eeg_id}.parquet')","metadata":{"execution":{"iopub.status.busy":"2024-05-07T10:59:57.700364Z","iopub.execute_input":"2024-05-07T10:59:57.701243Z","iopub.status.idle":"2024-05-07T10:59:57.707439Z","shell.execute_reply.started":"2024-05-07T10:59:57.701208Z","shell.execute_reply":"2024-05-07T10:59:57.706482Z"},"trusted":true},"execution_count":9,"outputs":[]},
{"cell_type":"code","source":"row_id = 5859\nnp.isnan(get_one_row_data(row_id)).any()","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:00:02.032388Z","iopub.execute_input":"2024-05-07T11:00:02.033191Z","iopub.status.idle":"2024-05-07T11:00:02.215237Z","shell.execute_reply.started":"2024-05-07T11:00:02.033159Z","shell.execute_reply":"2024-05-07T11:00:02.214377Z"},"trusted":true},"execution_count":10,"outputs":[{"execution_count":10,"output_type":"execute_result","data":{"text/plain":"False"},"metadata":{}}]},
{"cell_type":"markdown","source":"### Build a graph from each row's EEG data","metadata":{}},
{"cell_type":"code","source":"class EEGDataset(Dataset):\n    \"\"\"Map-style dataset that turns each row of `df` into a graph.\n\n    Item `index` loads the EEG window for `df.iloc[index].eeg_id` from\n    EEG_PATH_TRAIN with eeg_from_parquet, passes the transposed tensor to\n    `transform` (taking its 'eeg' entry as the graph) and stores the\n    row-normalised target votes as `graph.y` with shape (1, len(target_cols)).\n    \"\"\"\n\n    def __init__(self, df, scaler=None, transform=None, target_cols=None):\n        self.df = df\n        self.path = EEG_PATH_TRAIN\n        # Defaults are created per instance instead of once at definition\n        # time, so datasets never share a (stateful) scaler or transform.\n        # NOTE: the scaler is stored but not applied in __getitem__.\n        self.scaler = StandardScaler() if scaler is None else scaler\n        if transform is None:\n            transform = ToDynamicG(\n                edge_func='absolute_pearson_correlation_coefficient', threshold=0.7, binary=True)\n        self.transform = transform\n        # By default the six columns before the last one, i.e. the *_vote\n        # columns of `train` (eeg_id, patient_id, 6 votes, target).\n        if target_cols is None:\n            target_cols = df.columns[-7:-1]\n        self.target_cols = list(target_cols)\n\n    def __len__(self):\n        return len(self.df)\n\n    def __getitem__(self, index):\n        row = self.df.iloc[index]\n\n        eeg = torch.Tensor(eeg_from_parquet(f'{self.path}{row.eeg_id}.parquet'))\n        graph = self.transform(eeg=eeg.T)['eeg']\n\n        # Select targets by label, not by position within the row.\n        y = np.asarray(row[self.target_cols].values, dtype='float32').reshape(1, -1)\n        y = y / y.sum(axis=1, keepdims=True)\n        graph.y = torch.Tensor(y)\n        return graph","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:00:04.905876Z","iopub.execute_input":"2024-05-07T11:00:04.906531Z","iopub.status.idle":"2024-05-07T11:00:04.915594Z","shell.execute_reply.started":"2024-05-07T11:00:04.906498Z","shell.execute_reply":"2024-05-07T11:00:04.914497Z"},"trusted":true},"execution_count":11,"outputs":[]},
{"cell_type":"code","source":"transforms = Compose([\n    BandDifferentialEntropy(),\n    MeanStdNormalize(),\n    ToDynamicG(edge_func='absolute_pearson_correlation_coefficient', threshold=0.5, binary=True)\n])","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:00:10.041404Z","iopub.execute_input":"2024-05-07T11:00:10.041778Z","iopub.status.idle":"2024-05-07T11:00:10.046750Z","shell.execute_reply.started":"2024-05-07T11:00:10.041751Z","shell.execute_reply":"2024-05-07T11:00:10.045823Z"},"trusted":true},"execution_count":12,"outputs":[]},{"cell_type":"markdown","source":"## 
5859","metadata":{}},{"cell_type":"code","source":"train.iloc[5859]","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:00:13.889245Z","iopub.execute_input":"2024-05-07T11:00:13.889665Z","iopub.status.idle":"2024-05-07T11:00:13.898482Z","shell.execute_reply.started":"2024-05-07T11:00:13.889635Z","shell.execute_reply":"2024-05-07T11:00:13.897359Z"},"trusted":true},"execution_count":13,"outputs":[{"execution_count":13,"output_type":"execute_result","data":{"text/plain":"eeg_id 1457334423\npatient_id 30631\nseizure_vote 0.0\nlpd_vote 0.0\ngpd_vote 0.0\nlrda_vote 0.0\ngrda_vote 0.0\nother_vote 1.0\ntarget Other\nName: 5859, dtype: object"},"metadata":{}}]},{"cell_type":"code","source":"x = get_one_row_data(5859)\nde = BandDifferentialEntropy()\nde_x = de(eeg=x.T)['eeg']\nde_x_torch = torch.Tensor(de_x)\nt = ToDynamicG(edge_func='absolute_pearson_correlation_coefficient', threshold=0.5, binary=True)\nt(eeg=de_x_torch)","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:00:16.099497Z","iopub.execute_input":"2024-05-07T11:00:16.100424Z","iopub.status.idle":"2024-05-07T11:00:16.276394Z","shell.execute_reply.started":"2024-05-07T11:00:16.100381Z","shell.execute_reply":"2024-05-07T11:00:16.275273Z"},"trusted":true},"execution_count":14,"outputs":[{"execution_count":14,"output_type":"execute_result","data":{"text/plain":"{'eeg': Data(edge_index=[2, 361], x=[19, 4], edge_weight=[361])}"},"metadata":{}}]},{"cell_type":"code","source":"dataset = EEGDataset(train, transform=transforms)\ndataset.__getitem__(5859).x","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:00:18.228274Z","iopub.execute_input":"2024-05-07T11:00:18.228991Z","iopub.status.idle":"2024-05-07T11:00:18.364727Z","shell.execute_reply.started":"2024-05-07T11:00:18.228960Z","shell.execute_reply":"2024-05-07T11:00:18.363631Z"},"trusted":true},"execution_count":15,"outputs":[{"execution_count":15,"output_type":"execute_result","data":{"text/plain":"tensor([[ 1.3566, 0.2153, -0.1234, 
-1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485],\n [ 1.3566, 0.2153, -0.1234, -1.4485]])"},"metadata":{}}]},{"cell_type":"code","source":"def last_relevant_pytorch(output, lengths, batch_first=True):\n lengths = lengths.cpu()\n\n # masks of the true seq lengths\n masks = (lengths - 1).view(-1, 1).expand(len(lengths), output.size(2))\n time_dimension = 1 if batch_first else 0\n masks = masks.unsqueeze(time_dimension)\n masks = masks.to(output.device)\n last_output = output.gather(time_dimension, masks).squeeze(time_dimension)\n last_output.to(output.device)\n\n return last_output","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:05:22.496887Z","iopub.execute_input":"2024-05-07T11:05:22.497663Z","iopub.status.idle":"2024-05-07T11:05:22.503938Z","shell.execute_reply.started":"2024-05-07T11:05:22.497628Z","shell.execute_reply":"2024-05-07T11:05:22.502987Z"},"trusted":true},"execution_count":18,"outputs":[]},{"cell_type":"markdown","source":"## Create GNN","metadata":{}},{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nimport sys\n\nclass CNN_LSTM(nn.Module):\n def __init__(self, num_classes=1):\n super(CNN_LSTM, self).__init__()\n self.num_classes = num_classes\n \n self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)\n self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)\n self.pool = 
nn.MaxPool2d(kernel_size=2)\n self.fc1 = nn.Linear(32*48*7, 512)\n\n self.lstm = nn.LSTM(input_size=512, hidden_size=128, num_layers=2)\n self.fc2 = nn.Linear(128, num_classes)\n\n def forward(self, x, seq_lengths):\n\n batch, max_seq_len, num_ch, in_dim = x.shape\n x = x.reshape(-1, num_ch, in_dim).unsqueeze(1)\n\n out = self.conv1(x)\n out = self.conv2(out)\n out = self.pool(out)\n\n out = out.reshape(batch*max_seq_len, -1)\n out = self.fc1(out)\n out = out.reshape(batch, max_seq_len, -1)\n\n lstm_out, _ = self.lstm(out)\n lstm_out = last_relevant_pytorch(lstm_out, seq_lengths, batch_first=True)\n logits = self.fc2(lstm_out)\n\n return logits","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:05:24.898095Z","iopub.execute_input":"2024-05-07T11:05:24.898702Z","iopub.status.idle":"2024-05-07T11:05:24.908657Z","shell.execute_reply.started":"2024-05-07T11:05:24.898665Z","shell.execute_reply":"2024-05-07T11:05:24.907679Z"},"trusted":true},"execution_count":19,"outputs":[]},{"cell_type":"markdown","source":"## Train","metadata":{}},{"cell_type":"code","source":"def criterion(logit, target):\n log_prob = F.log_softmax(logit, dim=1)\n return F.kl_div(log_prob, target, reduction=\"batchmean\")\n\ndef KL_loss(p,q):\n epsilon=10**(-15)\n p=torch.clip(p,epsilon,1-epsilon)\n q = nn.functional.log_softmax(q,dim=1)\n return torch.mean(torch.sum(p*(torch.log(p)-q),dim=1))\n\ndef compute_loss(model, data_loader):\n model.eval()\n l_loss = []\n with torch.no_grad():\n for data in data_loader:\n data.to(config['device'])\n y_pred = model(data)\n loss = criterion(y_pred, data.y)\n l_loss.append(loss.item())\n return np.mean(l_loss) 
","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:05:28.131591Z","iopub.execute_input":"2024-05-07T11:05:28.132331Z","iopub.status.idle":"2024-05-07T11:05:28.140023Z","shell.execute_reply.started":"2024-05-07T11:05:28.132285Z","shell.execute_reply":"2024-05-07T11:05:28.139031Z"},"trusted":true},"execution_count":20,"outputs":[]},{"cell_type":"code","source":"%%time\nkf = KFold(n_splits=config['n_fold'], shuffle=True, random_state=config['seed'])\n\nl_best_loss = []\n\nfor fold, (iloc_train, iloc_valid) in enumerate(kf.split(train)):\n print(f\"Fold {fold}:\")\n \n train_ds = EEGDataset(df=train.iloc[iloc_train], transform=transforms)\n valid_ds = EEGDataset(df=train.iloc[iloc_valid], transform=transforms)\n train_loader = DataLoader(dataset=train_ds, shuffle=True, batch_size=config['batch_size'], \n num_workers=os.cpu_count(), drop_last=True)\n valid_loader = DataLoader(dataset=valid_ds, batch_size=config['batch_size'], \n num_workers=os.cpu_count())\n \n model = CNN_LSTM().to(config['device'])#EEG_Classifier(config['hidden_channels']).to(config['device'])\n optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], \n weight_decay=config['weight_decay'])\n scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=config['epochs'])\n \n best_loss = float(\"inf\")\n history = []\n \n for epoch in range(config['epochs']):\n model.train()\n l_loss = []\n for data in train_loader:\n data.to(config['device'])\n optimizer.zero_grad()\n out = model(data)\n loss = criterion(out, data.y)\n l_loss.append(loss.item())\n\n print(f\"epoch={epoch}\\t loss={loss}\")\n loss.backward()\n optimizer.step()\n train_loss = np.mean(l_loss)\n valid_loss = compute_loss(model, valid_loader)\n history.append((epoch, train_loss, valid_loss))\n print(f\"Epoch {epoch}\")\n print(f\"Train Loss: {train_loss:>10.6f}, Valid Loss: {valid_loss:>10.6}\")\n\n if valid_loss < best_loss:\n print(f\"Loss improves from {best_loss:>10.6f} to {valid_loss:>10.6}\")\n 
torch.save(model.state_dict(), f\"{'basic_GNN'}__{fold}.pt\")\n best_loss = valid_loss\n print(f\"\\nBest loss Model training with {best_loss}\\n\")\n l_best_loss.append(best_loss)\n \n history = pd.DataFrame(history, columns=[\"epoch\", \"loss\", \"val_loss\"]).set_index(\"epoch\")\n history.plot(subplots=True, layout=(1, 2), sharey=\"row\", figsize=(14, 6))\n plt.show()\n","metadata":{"execution":{"iopub.status.busy":"2024-05-07T11:09:35.101126Z","iopub.execute_input":"2024-05-07T11:09:35.101596Z"},"trusted":true},"execution_count":null,"outputs":[{"name":"stdout","text":"Fold 0:\n","output_type":"stream"}]}]}