diff --git a/examples/chestxray14_binary_classification.ipynb b/examples/chestxray14_binary_classification.ipynb deleted file mode 100644 index 270fa3af7..000000000 --- a/examples/chestxray14_binary_classification.ipynb +++ /dev/null @@ -1,1426 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "gpuType": "T4" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "2750023fb2bc420c875b3fde2cef2843": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_d942639eecdf4f3c955a5ceabb2dd012", - "IPY_MODEL_21aed97b90ea496dafbbfde642b8da3d", - "IPY_MODEL_c38852bb4f82461dbc2f2ad28949e04f" - ], - "layout": "IPY_MODEL_36b15a47acca4d19ad19aa7a75d6adc4" - } - }, - "d942639eecdf4f3c955a5ceabb2dd012": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_489e34fc92e041f39771ff701ddd6969", - "placeholder": "​", - "style": "IPY_MODEL_00a24949f2324b8f855ec5bfdc92d434", - "value": "Epoch 0 / 1: 100%" - } - }, - "21aed97b90ea496dafbbfde642b8da3d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ff1caf7de77b4b279a3b6d35669b79dd", - "max": 219, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_4c653f14317240edbea90f9b198c10cf", - "value": 219 - } - }, - "c38852bb4f82461dbc2f2ad28949e04f": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a5e6278ca52d47f1b638e9a94be80cbd", - "placeholder": "​", - "style": "IPY_MODEL_85cac2d243444b52a23f40b87d1b4023", - "value": " 219/219 [00:44<00:00,  5.12it/s]" - } - }, - "36b15a47acca4d19ad19aa7a75d6adc4": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "489e34fc92e041f39771ff701ddd6969": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "00a24949f2324b8f855ec5bfdc92d434": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "ff1caf7de77b4b279a3b6d35669b79dd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4c653f14317240edbea90f9b198c10cf": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "a5e6278ca52d47f1b638e9a94be80cbd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "85cac2d243444b52a23f40b87d1b4023": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - } - } - } - }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Binary Classification Using the ChestX-ray14 Dataset" - ], - "metadata": { - "id": "HaDNCcQJ3tD7" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Step 0: Install PyHealth" - ], - "metadata": { - "id": "j9Zj-n54qEwL" - } - }, - { - "cell_type": "code", - "source": [ - "!rm -rf PyHealth\n", - "!git clone https://github.com/EricSchrock/PyHealth.git\n", - "%cd PyHealth\n", - "!git checkout ChestX-ray14\n", - "!pip install -e ." - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "collapsed": true, - "id": "TWEAeB85p0C7", - "outputId": "32a89b86-4c11-49ca-9c46-867dbfaf2fa7" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Cloning into 'PyHealth'...\n", - "remote: Enumerating objects: 8101, done.\u001b[K\n", - "remote: Counting objects: 100% (1761/1761), done.\u001b[K\n", - "remote: Compressing objects: 100% (512/512), done.\u001b[K\n", - "remote: Total 8101 (delta 1555), reused 1251 (delta 1249), pack-reused 6340 (from 2)\u001b[K\n", - "Receiving objects: 100% (8101/8101), 113.88 MiB | 26.69 MiB/s, done.\n", - "Resolving deltas: 100% (5242/5242), done.\n", - "/content/PyHealth\n", - "Branch 'ChestX-ray14' set up to track remote branch 'ChestX-ray14' from 'origin'.\n", - "Switched to a new branch 'ChestX-ray14'\n", - "Obtaining file:///content/PyHealth\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n", - " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.11.0)\n", - "Collecting mne~=1.10.0 (from pyhealth==2.0a8)\n", - " Downloading mne-1.10.2-py3-none-any.whl.metadata (21 kB)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (3.5)\n", - "Collecting numpy~=1.26.4 (from pyhealth==2.0a8)\n", - " Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ogb>=1.3.5 (from pyhealth==2.0a8)\n", - " Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)\n", - "Collecting pandarallel~=1.6.5 (from pyhealth==2.0a8)\n", - " Downloading pandarallel-1.6.5.tar.gz (14 kB)\n", - " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "Collecting pandas~=2.3.1 (from pyhealth==2.0a8)\n", - " Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.17.1)\n", - "Requirement already satisfied: polars~=1.31.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.31.0)\n", - "Requirement already satisfied: pydantic~=2.11.7 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.11.10)\n", - "Collecting rdkit (from pyhealth==2.0a8)\n", - " Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)\n", - "Collecting scikit-learn~=1.7.0 (from pyhealth==2.0a8)\n", - " Downloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.23.0+cu126)\n", - "Collecting torch~=2.7.1 (from pyhealth==2.0a8)\n", - " Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.67.1)\n", - "Collecting transformers~=4.53.2 (from pyhealth==2.0a8)\n", - " Downloading transformers-4.53.3-py3-none-any.whl.metadata (40 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: urllib3~=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.5.0)\n", - "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (4.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.1.6)\n", - "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (0.4)\n", - "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.10.0)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (25.0)\n", - "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.8.2)\n", - "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.16.3)\n", - "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (1.17.0)\n", - "Collecting outdated>=0.2.0 (from ogb>=1.3.5->pyhealth==2.0a8)\n", - " Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)\n", - "Requirement already satisfied: dill>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (0.3.8)\n", - "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (5.9.5)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.7.0)\n", - "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (2.33.2)\n", - "Requirement already satisfied: typing-extensions>=4.12.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (4.15.0)\n", - "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.4.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (1.5.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (3.6.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.20.0)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (75.2.0)\n", - "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.13.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2025.3.0)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.80)\n", - "Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.4.1)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.3.0.4)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (10.3.7.77)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.7.1.2)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.5.4.2)\n", - "Collecting nvidia-cusparselt-cu12==0.6.3 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl.metadata (6.8 kB)\n", - "Collecting nvidia-nccl-cu12==2.26.2 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.85)\n", - "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.11.1.6)\n", - "Collecting triton==3.3.1 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)\n", - "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.36.0)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (6.0.3)\n", - "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2024.11.6)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2.32.4)\n", - "Collecting tokenizers<0.22,>=0.21 (from transformers~=4.53.2->pyhealth==2.0a8)\n", - " Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", - "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.6.2)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from rdkit->pyhealth==2.0a8) (11.3.0)\n", - "INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.\n", - "Collecting torchvision (from pyhealth==2.0a8)\n", - " Downloading torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n", - " Downloading torchvision-0.24.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n", - " Downloading torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n", - " Downloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n", - "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers~=4.53.2->pyhealth==2.0a8) (1.2.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.3.3)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (4.60.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.4.9)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (3.2.5)\n", - "Collecting littleutils (from outdated>=0.2.0->ogb>=1.3.5->pyhealth==2.0a8)\n", - " Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)\n", - "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne~=1.10.0->pyhealth==2.0a8) (4.5.0)\n", - "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.4.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.11)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (2025.10.5)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch~=2.7.1->pyhealth==2.0a8) (1.3.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne~=1.10.0->pyhealth==2.0a8) (3.0.3)\n", - "Downloading mne-1.10.2-py3-none-any.whl (7.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m100.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.0/18.0 MB\u001b[0m \u001b[31m125.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading ogb-1.3.6-py3-none-any.whl (78 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.8/78.8 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m145.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.5/9.5 MB\u001b[0m \u001b[31m145.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl (821.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m821.0/821.0 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl (571.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m571.0/571.0 MB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl (156.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m156.8/156.8 MB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (201.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m201.3/201.3 MB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (155.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m155.7/155.7 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading transformers-4.53.3-py3-none-any.whl (10.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m133.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.2/36.2 MB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl (7.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m84.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)\n", - "Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m63.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading littleutils-0.2.4-py3-none-any.whl (8.1 kB)\n", - "Building wheels for collected packages: pyhealth, pandarallel\n", - " Building editable for pyhealth (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pyhealth: filename=pyhealth-2.0a8-py3-none-any.whl size=10674 sha256=958c7e0bd8938910e22eda0840e62272710f8cae2e42ad8531f1012a34cd222f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-c1tiyeqt/wheels/1c/98/da/d6e74a692d0be5faeba6025d7302fd470b1ee8167b77261ad6\n", - " Building wheel for pandarallel (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pandarallel: filename=pandarallel-1.6.5-py3-none-any.whl size=16674 sha256=d2ad066c2563268e9811ae2c6adb46d872232aa5ad8891689caf3aea26b89d42\n", - " Stored in directory: /root/.cache/pip/wheels/46/f9/0d/40c9cd74a7cb8dc8fe57e8d6c3c19e2c730449c0d3f2bf66b5\n", - "Successfully built pyhealth pandarallel\n", - "Installing collected packages: nvidia-cusparselt-cu12, triton, nvidia-nccl-cu12, nvidia-cudnn-cu12, numpy, littleutils, rdkit, pandas, outdated, torch, tokenizers, scikit-learn, pandarallel, transformers, torchvision, ogb, mne, pyhealth\n", - " Attempting uninstall: nvidia-cusparselt-cu12\n", - " Found existing installation: nvidia-cusparselt-cu12 0.7.1\n", - " Uninstalling nvidia-cusparselt-cu12-0.7.1:\n", - " Successfully uninstalled nvidia-cusparselt-cu12-0.7.1\n", - " Attempting uninstall: triton\n", - " Found existing installation: triton 3.4.0\n", - " Uninstalling triton-3.4.0:\n", - " Successfully uninstalled triton-3.4.0\n", - " Attempting uninstall: nvidia-nccl-cu12\n", - " Found existing installation: nvidia-nccl-cu12 2.27.3\n", - " Uninstalling nvidia-nccl-cu12-2.27.3:\n", - " Successfully uninstalled nvidia-nccl-cu12-2.27.3\n", - " Attempting uninstall: nvidia-cudnn-cu12\n", - " Found existing installation: nvidia-cudnn-cu12 9.10.2.21\n", - " Uninstalling nvidia-cudnn-cu12-9.10.2.21:\n", - " Successfully uninstalled nvidia-cudnn-cu12-9.10.2.21\n", - " Attempting uninstall: numpy\n", - " Found existing installation: numpy 2.0.2\n", - " Uninstalling numpy-2.0.2:\n", - " Successfully uninstalled numpy-2.0.2\n", - " Attempting uninstall: pandas\n", - " Found existing installation: pandas 2.2.2\n", - " Uninstalling pandas-2.2.2:\n", - " Successfully uninstalled pandas-2.2.2\n", - " Attempting uninstall: torch\n", - " Found existing installation: torch 2.8.0+cu126\n", - " Uninstalling torch-2.8.0+cu126:\n", - " Successfully uninstalled torch-2.8.0+cu126\n", - " Attempting uninstall: tokenizers\n", - " Found existing installation: tokenizers 0.22.1\n", - " Uninstalling tokenizers-0.22.1:\n", - " Successfully uninstalled tokenizers-0.22.1\n", - " Attempting uninstall: scikit-learn\n", - " Found existing installation: scikit-learn 1.6.1\n", - " Uninstalling scikit-learn-1.6.1:\n", - " Successfully uninstalled scikit-learn-1.6.1\n", - " Attempting uninstall: transformers\n", - " Found existing installation: transformers 4.57.1\n", - " Uninstalling transformers-4.57.1:\n", - " Successfully uninstalled transformers-4.57.1\n", - " Attempting uninstall: torchvision\n", - " Found existing installation: torchvision 0.23.0+cu126\n", - " Uninstalling torchvision-0.23.0+cu126:\n", - " Successfully uninstalled torchvision-0.23.0+cu126\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.\n", - "pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", - "opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n", - "torchaudio 2.8.0+cu126 requires torch==2.8.0, but you have torch 2.7.1 which is incompatible.\n", - "opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", - "opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed littleutils-0.2.4 mne-1.10.2 numpy-1.26.4 nvidia-cudnn-cu12-9.5.1.17 nvidia-cusparselt-cu12-0.6.3 nvidia-nccl-cu12-2.26.2 ogb-1.3.6 outdated-0.2.2 pandarallel-1.6.5 pandas-2.3.3 pyhealth-2.0a8 rdkit-2025.9.1 scikit-learn-1.7.2 tokenizers-0.21.4 torch-2.7.1 torchvision-0.22.1 transformers-4.53.3 triton-3.3.1\n" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "numpy" - ] - }, - "id": "3737617eb2cf402699bacea64f559c14" - } - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 1: Load Dataset" - ], - "metadata": { - "id": "rMjzPqNbscDV" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import ChestXray14Dataset\n", - "\n", - "dataset = ChestXray14Dataset(download=True, partial=True)\n", - "dataset.stats()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q_fTVUTrsryn", - "outputId": "0660d909-31c6-48df-bb98-a015e48dd88d" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Dataset: ChestX-ray14\n", - "Dev mode: False\n", - "Number of patients: 1335\n", - "Number of events: 4999\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 2: Define Task" - ], - "metadata": { - "id": "ecF9IgCb22N5" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.tasks import ChestXray14BinaryClassification\n", - "\n", - "task = ChestXray14BinaryClassification(disease=\"infiltration\")\n", - "samples = dataset.set_task(task)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uj9ALkQGtVqF", - "outputId": "bfb6c953-9411-40be-9ca0-060077cbad96" - }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Setting task ChestXray14BinaryClassification for ChestX-ray14 base dataset...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Setting task ChestXray14BinaryClassification for ChestX-ray14 base dataset...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generating samples with 1 worker(s)...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Generating samples with 1 worker(s)...\n", - "Generating samples for ChestXray14BinaryClassification with 1 worker: 100%|██████████| 1335/1335 [00:00<00:00, 1770.85it/s]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Label label vocab: {0: 0, 1: 1}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n", - "INFO:pyhealth.processors.label_processor:Label label vocab: {0: 0, 1: 1}\n", - "Processing samples: 100%|██████████| 4999/4999 [01:22<00:00, 60.94it/s]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generated 4999 samples for task ChestXray14BinaryClassification\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n", - "INFO:pyhealth.datasets.base_dataset:Generated 4999 samples for task ChestXray14BinaryClassification\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import get_dataloader, split_by_sample\n", - "\n", - "train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.7, 0.1, 0.2])\n", - "\n", - "train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)\n", - "test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)" - ], - "metadata": { - "id": "8qS3hfKX5GNo" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 3: Define Model" - ], - "metadata": { - "id": "SjonWePy1r6N" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.models import CNN\n", - "\n", - "model = CNN(dataset=samples)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "VydSOr8u0XWG", - "outputId": "9f3bf251-0b5d-4457-9dfe-14f5d1beaeb5" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/content/PyHealth/pyhealth/metrics/calibration.py:122: SyntaxWarning: invalid escape sequence '\\c'\n", - " accuracy of 1. Thus, the ECE is :math:`\\\\frac{1}{3} \\cdot 0.49 + \\\\frac{2}{3}\\cdot 0.3=0.3633`.\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Warning: No embedding created for field due to lack of compatible processor: image\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 4: Train Model" - ], - "metadata": { - "id": "0jqDpKxgAu3-" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.trainer import Trainer\n", - "\n", - "trainer = Trainer(model=model)\n", - "trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "2750023fb2bc420c875b3fde2cef2843", - "d942639eecdf4f3c955a5ceabb2dd012", - "21aed97b90ea496dafbbfde642b8da3d", - "c38852bb4f82461dbc2f2ad28949e04f", - "36b15a47acca4d19ad19aa7a75d6adc4", - "489e34fc92e041f39771ff701ddd6969", - "00a24949f2324b8f855ec5bfdc92d434", - "ff1caf7de77b4b279a3b6d35669b79dd", - "4c653f14317240edbea90f9b198c10cf", - "a5e6278ca52d47f1b638e9a94be80cbd", - "85cac2d243444b52a23f40b87d1b4023" - ] - }, - "id": "-our6gpdAyGD", - "outputId": "ac84d73f-9940-4333-8f7e-f7b9f2614da9" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=1, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=1, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Metrics: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Metrics: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Epoch 0 / 1: 0%| | 0/219 [00:00=1.3.5 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.3.6)\n", - "Requirement already satisfied: pandarallel~=1.6.5 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.6.5)\n", - "Requirement already satisfied: pandas~=2.3.1 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.3.3)\n", - "Requirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.17.1)\n", - "Requirement already satisfied: polars~=1.31.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.31.0)\n", - "Requirement already satisfied: pydantic~=2.11.7 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.11.10)\n", - "Requirement already satisfied: rdkit in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2025.9.1)\n", - "Requirement already satisfied: scikit-learn~=1.7.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.7.2)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.22.1)\n", - "Requirement already satisfied: torch~=2.7.1 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.7.1)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.67.1)\n", - "Requirement already satisfied: transformers~=4.53.2 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.53.3)\n", - "Requirement already satisfied: urllib3~=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.5.0)\n", - "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (4.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.1.6)\n", - "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (0.4)\n", - "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.10.0)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (25.0)\n", - "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.8.2)\n", - "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.16.3)\n", - "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (1.17.0)\n", - "Requirement already satisfied: outdated>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (0.2.2)\n", - "Requirement already satisfied: dill>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (0.3.8)\n", - "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (5.9.5)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.7.0)\n", - "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (2.33.2)\n", - "Requirement already satisfied: typing-extensions>=4.12.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (4.15.0)\n", - "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.4.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (1.5.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (3.6.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.20.0)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (75.2.0)\n", - "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.13.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2025.3.0)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.80)\n", - "Requirement already satisfied: nvidia-cudnn-cu12==9.5.1.17 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (9.5.1.17)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.4.1)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.3.0.4)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (10.3.7.77)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.7.1.2)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.5.4.2)\n", - "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (0.6.3)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2.26.2)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.85)\n", - "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.11.1.6)\n", - "Requirement already satisfied: triton==3.3.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.3.1)\n", - "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.36.0)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (6.0.3)\n", - "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2024.11.6)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2.32.4)\n", - "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.21.4)\n", - "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.6.2)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from rdkit->pyhealth==2.0a8) (11.3.0)\n", - "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers~=4.53.2->pyhealth==2.0a8) (1.2.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.3.3)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (4.60.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.4.9)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (3.2.5)\n", - "Requirement already satisfied: littleutils in /usr/local/lib/python3.12/dist-packages (from outdated>=0.2.0->ogb>=1.3.5->pyhealth==2.0a8) (0.2.4)\n", - "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne~=1.10.0->pyhealth==2.0a8) (4.5.0)\n", - "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.4.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.11)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (2025.10.5)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch~=2.7.1->pyhealth==2.0a8) (1.3.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne~=1.10.0->pyhealth==2.0a8) (3.0.3)\n", - "Building wheels for collected packages: pyhealth\n", - " Building editable for pyhealth (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pyhealth: filename=pyhealth-2.0a8-py3-none-any.whl size=10674 sha256=958c7e0bd8938910e22eda0840e62272710f8cae2e42ad8531f1012a34cd222f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-netvrq88/wheels/1c/98/da/d6e74a692d0be5faeba6025d7302fd470b1ee8167b77261ad6\n", - "Successfully built pyhealth\n", - "Installing collected packages: pyhealth\n", - " Attempting uninstall: pyhealth\n", - " Found existing installation: pyhealth 2.0a8\n", - " Uninstalling pyhealth-2.0a8:\n", - " Successfully uninstalled pyhealth-2.0a8\n", - "Successfully installed pyhealth-2.0a8\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 1: Load Dataset" - ], - "metadata": { - "id": "rMjzPqNbscDV" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import ChestXray14Dataset\n", - "\n", - "dataset = ChestXray14Dataset(download=True, partial=True)\n", - "dataset.stats()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q_fTVUTrsryn", - "outputId": "942b186a-dc4d-4b05-eedd-c0d285aae951" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Dataset: ChestX-ray14\n", - "Dev mode: False\n", - "Number of patients: 1335\n", - "Number of events: 4999\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 2: Define Task" - ], - "metadata": { - "id": "ecF9IgCb22N5" - } - }, - { - "cell_type": "code", - "source": [ - "samples = dataset.set_task()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uj9ALkQGtVqF", - "outputId": "076cbb31-879f-4414-963e-7e3631f0ed31" - }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Setting task ChestXray14MultilabelClassification for ChestX-ray14 base dataset...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Setting task ChestXray14MultilabelClassification for ChestX-ray14 base dataset...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generating samples with 1 worker(s)...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Generating samples with 1 worker(s)...\n", - "Generating samples for ChestXray14MultilabelClassification with 1 worker: 100%|██████████| 1335/1335 [00:00<00:00, 1475.55it/s]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Label labels vocab: {'atelectasis': 0, 'cardiomegaly': 1, 'consolidation': 2, 'edema': 3, 'effusion': 4, 'emphysema': 5, 'fibrosis': 6, 'hernia': 7, 'infiltration': 8, 'mass': 9, 'nodule': 10, 'pleural_thickening': 11, 'pneumonia': 12, 'pneumothorax': 13}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n", - "INFO:pyhealth.processors.label_processor:Label labels vocab: {'atelectasis': 0, 'cardiomegaly': 1, 'consolidation': 2, 'edema': 3, 'effusion': 4, 'emphysema': 5, 'fibrosis': 6, 'hernia': 7, 'infiltration': 8, 'mass': 9, 'nodule': 10, 'pleural_thickening': 11, 'pneumonia': 12, 'pneumothorax': 13}\n", - "Processing samples: 100%|██████████| 4999/4999 [01:18<00:00, 63.31it/s]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generated 4999 samples for task ChestXray14MultilabelClassification\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n", - "INFO:pyhealth.datasets.base_dataset:Generated 4999 samples for task ChestXray14MultilabelClassification\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import get_dataloader, split_by_sample\n", - "\n", - "train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.7, 0.1, 0.2])\n", - "\n", - "train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)\n", - "test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)" - ], - "metadata": { - "id": "8qS3hfKX5GNo" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 3: Define Model" - ], - "metadata": { - "id": "SjonWePy1r6N" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.models import CNN\n", - "\n", - "model = CNN(dataset=samples)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "VydSOr8u0XWG", - "outputId": "52e4df8b-00fd-47b7-e3ce-5836df69ffa9" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/content/PyHealth/pyhealth/metrics/calibration.py:122: SyntaxWarning: invalid escape sequence '\\c'\n", - " accuracy of 1. Thus, the ECE is :math:`\\\\frac{1}{3} \\cdot 0.49 + \\\\frac{2}{3}\\cdot 0.3=0.3633`.\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Warning: No embedding created for field due to lack of compatible processor: image\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 4: Train Model" - ], - "metadata": { - "id": "0jqDpKxgAu3-" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.trainer import Trainer\n", - "\n", - "# Only measure accurancy because with the \"partial\" dataset it is likely that\n", - "# there are not positive samples of every label present in the validation and test sets\n", - "trainer = Trainer(model=model, metrics=[\"accuracy\"])\n", - "trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "d5764f3cccdf4c52a25d0b8b2071e3b3", - "775da3f0d3e643f793ba6ad6aefdefca", - "87c9b17add0b434a897d46ec46826b4c", - "cba94b14345e40f4806e26a41edb72bc", - "207a6f173e57485b9abb66eb5f259c74", - "b32f9870995c4af3890deb5af41a77e0", - "517d39e922b543b4a111c1c72dc2abbd", - "3c26b80083274382b36f31add64ed5ed", - "8a9b003734834976aafca099ab6a37a5", - "059bda91051a46c18e0b02aa11eb73fa", - "ab3c660b9c2449619cca4a9d31392391" - ] - }, - "id": "-our6gpdAyGD", - "outputId": "d7360434-f396-4ffc-d348-04bfc6c3a524" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=14, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=14, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Metrics: ['accuracy']\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Metrics: ['accuracy']\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Epoch 0 / 1: 0%| | 0/219 [00:00>> dataset = ChestXray14Dataset(root="./data") """ self._label_path: str = os.path.join(root, "Data_Entry_2017_v2020.csv") @@ -98,7 +98,7 @@ def default_task(self) -> ChestXray14MultilabelClassification: Returns: ChestXray14MultilabelClassification: The default classification task. - Example: + Example:: >>> dataset = ChestXray14Dataset() >>> task = dataset.default_task """ @@ -118,12 +118,20 @@ def set_task(self, *args, **kwargs): return super().set_task(*args, **kwargs) + set_task.__doc__ = ( + f"{set_task.__doc__}\n" + " Note:\n" + " If no image processor is provided, a default grayscale `ImageProcessor(mode='L')` is injected. " + "This is needed because the ChestX-ray14 dataset images do not all have the same number of channels, " + "causing the default PyHealth image processor to fail." + ) + def _download(self, root: str, partial: bool) -> None: """Downloads and verifies the ChestX-ray14 dataset files. This method performs the following steps: - 1. Downloads the label CSV file from a Google Drive mirror. - 2. Downloads compressed image archives from NIH Box links. + 1. Downloads the label CSV file from the shared NIH Box folder. + 2. Downloads compressed image archives from static NIH Box links. 3. Verifies the integrity of each downloaded file using its MD5 checksum. 4. Extracts the image archives to the dataset directory. 5. Removes the original compressed files after successful extraction. @@ -138,11 +146,18 @@ def _download(self, root: str, partial: bool) -> None: ValueError: If an image tar file contains an unsafe path. ValueError: If an unexpected number of images are downloaded. """ - # https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468 (mirrored to Google Drive) - # I couldn't figure out a way to download this file directly from box.com - response = requests.get('https://drive.google.com/uc?export=download&id=1mkOZNfYt-Px52b8CJZJANNbM3ULUVO3f') - with open(self._label_path, "wb") as file: - file.write(response.content) + response = requests.get( + url=( + "https://nihcc.app.box.com/index.php" + "?rm=box_download_shared_file" + "&vanity_name=ChestXray-NIHCC" + "&file_id=f_219760887468" + ), + allow_redirects=True, + ) + + with open(self._label_path, "wb") as f: + f.write(response.content) # https://nihcc.app.box.com/v/ChestXray-NIHCC/file/371647823217 links = [ diff --git a/test-resources/core/chestxray14/Data_Entry_2017_v2020.csv b/test-resources/core/chestxray14/Data_Entry_2017_v2020.csv new file mode 100644 index 000000000..8d41da5dd --- /dev/null +++ b/test-resources/core/chestxray14/Data_Entry_2017_v2020.csv @@ -0,0 +1,11 @@ +Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Sex,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y], +00000001_000.png,Cardiomegaly,0,1,57,M,PA,2682,2749,0.14300000000000002,0.14300000000000002, +00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.14300000000000002,0.14300000000000002, +00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168, +00000002_000.png,No Finding,0,2,80,M,PA,2500,2048,0.171,0.171, +00000003_001.png,Hernia,0,3,74,F,PA,2500,2048,0.168,0.168, +00000003_002.png,Hernia,1,3,75,F,PA,2048,2500,0.168,0.168, +00000003_003.png,Hernia|Infiltration,2,3,76,F,PA,2698,2991,0.14300000000000002,0.14300000000000002, +00000003_004.png,Hernia,3,3,77,F,PA,2500,2048,0.168,0.168, +00000003_005.png,Hernia,4,3,78,F,PA,2686,2991,0.14300000000000002,0.14300000000000002, +00000003_006.png,Hernia,5,3,79,F,PA,2992,2991,0.14300000000000002,0.14300000000000002, \ No newline at end of file diff --git a/test-resources/core/chestxray14/images/.gitkeep b/test-resources/core/chestxray14/images/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/tests/core/test_chestxray14.py b/tests/core/test_chestxray14.py index bf43db820..a1317813a 100644 --- a/tests/core/test_chestxray14.py +++ b/tests/core/test_chestxray14.py @@ -4,8 +4,7 @@ Author: Eric Schrock (ejs9@illinois.edu) """ -import os -import shutil +from pathlib import Path import tempfile import unittest @@ -19,38 +18,10 @@ class TestChestXray14Dataset(unittest.TestCase): @classmethod def setUpClass(cls): - if os.path.exists("test"): - shutil.rmtree("test") - os.makedirs("test/images") - - # Source: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468 - lines = [ - "Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Sex,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],", - "00000001_000.png,Cardiomegaly,0,1,57,M,PA,2682,2749,0.14300000000000002,0.14300000000000002,", - "00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.14300000000000002,0.14300000000000002,", - "00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,", - "00000002_000.png,No Finding,0,2,80,M,PA,2500,2048,0.171,0.171,", - "00000003_001.png,Hernia,0,3,74,F,PA,2500,2048,0.168,0.168,", - "00000003_002.png,Hernia,1,3,75,F,PA,2048,2500,0.168,0.168,", - "00000003_003.png,Hernia|Infiltration,2,3,76,F,PA,2698,2991,0.14300000000000002,0.14300000000000002,", - "00000003_004.png,Hernia,3,3,77,F,PA,2500,2048,0.168,0.168,", - "00000003_005.png,Hernia,4,3,78,F,PA,2686,2991,0.14300000000000002,0.14300000000000002,", - "00000003_006.png,Hernia,5,3,79,F,PA,2992,2991,0.14300000000000002,0.14300000000000002,", - ] - - # Create mock images to test image loading - for line in lines[1:]: # Skip header row - name = line.split(',')[0] - img = Image.fromarray(np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8), mode="RGBA") - img.save(os.path.join("test/images", name)) - - # Save image labels to file - with open("test/Data_Entry_2017_v2020.csv", 'w') as f: - f.write("\n".join(lines)) - + cls.root = Path(__file__).parent.parent.parent / "test-resources" / "core" / "chestxray14" + cls.generate_fake_images() cls.cache_dir = tempfile.TemporaryDirectory() - - cls.dataset = ChestXray14Dataset(root="./test", cache_dir=cls.cache_dir.name) + cls.dataset = ChestXray14Dataset(cls.root, cache_dir=cls.cache_dir.name) cls.samples_cardiomegaly = cls.dataset.set_task(ChestXray14BinaryClassification(disease="cardiomegaly")) cls.samples_hernia = cls.dataset.set_task(ChestXray14BinaryClassification(disease="hernia")) @@ -62,8 +33,23 @@ def tearDownClass(cls): cls.samples_hernia.close() cls.samples_multilabel.close() - if os.path.exists("test"): - shutil.rmtree("test") + Path(cls.dataset.root / "chestxray14-metadata-pyhealth.csv").unlink() + cls.delete_fake_images() + + @classmethod + def generate_fake_images(cls): + with open(Path(cls.root / "Data_Entry_2017_v2020.csv"), 'r') as f: + lines = f.readlines() + + for line in lines[1:]: # Skip header row + name = line.split(',')[0] + img = Image.fromarray(np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8)) + img.save(Path(cls.root / "images" / name)) + + @classmethod + def delete_fake_images(cls): + for png in Path(cls.root / "images").glob("*.png"): + png.unlink() def test_stats(self): self.dataset.stats()