Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,426 changes: 0 additions & 1,426 deletions examples/chestxray14_binary_classification.ipynb

This file was deleted.

1,266 changes: 0 additions & 1,266 deletions examples/chestxray14_multilabel_classification.ipynb

This file was deleted.

42 changes: 42 additions & 0 deletions examples/cxr/chestxray14_binary_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import tempfile

from pyhealth.datasets import ChestXray14Dataset, get_dataloader, split_by_sample
from pyhealth.models import CNN
from pyhealth.tasks import ChestXray14BinaryClassification
from pyhealth.trainer import Trainer

# Since PyHealth uses multiprocessing, it is best practice to use a main guard.
if __name__ == '__main__':
    # Use the temporary directories as context managers so that both are
    # removed deterministically when the block exits, even if a later step
    # raises, instead of relying on implicit cleanup at garbage collection.
    with tempfile.TemporaryDirectory() as dataset_dir, \
            tempfile.TemporaryDirectory() as cache_dir:
        dataset = ChestXray14Dataset(
            root=dataset_dir,
            cache_dir=cache_dir,
            download=True,
            partial=True,
        )
        dataset.stats()

        # Binary task: presence/absence of a single disease label.
        task = ChestXray14BinaryClassification(disease="infiltration")
        samples = dataset.set_task(task)

        try:
            # 70% train / 10% validation / 20% test.
            train_dataset, val_dataset, test_dataset = split_by_sample(
                samples, [0.7, 0.1, 0.2]
            )

            train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)
            val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)
            test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)

            model = CNN(dataset=samples)

            trainer = Trainer(model=model)
            trainer.train(
                train_dataloader=train_loader,
                val_dataloader=val_loader,
                epochs=1,
            )

            trainer.evaluate(test_loader)
        finally:
            # Always release the sample dataset, and do so before the
            # temporary directories that back it are deleted.
            samples.close()
42 changes: 42 additions & 0 deletions examples/cxr/chestxray14_multilabel_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import tempfile

from pyhealth.datasets import ChestXray14Dataset, get_dataloader, split_by_sample
from pyhealth.models import CNN
from pyhealth.trainer import Trainer

# Since PyHealth uses multiprocessing, it is best practice to use a main guard.
if __name__ == '__main__':
    # Use the temporary directories as context managers so that both are
    # removed deterministically when the block exits, even if a later step
    # raises, instead of relying on implicit cleanup at garbage collection.
    with tempfile.TemporaryDirectory() as dataset_dir, \
            tempfile.TemporaryDirectory() as cache_dir:
        dataset = ChestXray14Dataset(
            root=dataset_dir,
            cache_dir=cache_dir,
            download=True,
            partial=True,
        )
        dataset.stats()

        # No task is passed, so the dataset's default task is used.
        samples = dataset.set_task()

        try:
            # 70% train / 10% validation / 20% test.
            train_dataset, val_dataset, test_dataset = split_by_sample(
                samples, [0.7, 0.1, 0.2]
            )

            train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)
            val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)
            test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)

            model = CNN(dataset=samples)

            # Only measure accuracy because, with the "partial" dataset, it is
            # likely that not every label has positive samples in the
            # validation and test sets.
            trainer = Trainer(model=model, metrics=["accuracy"])
            trainer.train(
                train_dataloader=train_loader,
                val_dataloader=val_loader,
                epochs=1,
            )

            trainer.evaluate(test_loader)
        finally:
            # Always release the sample dataset, and do so before the
            # temporary directories that back it are deleted.
            samples.close()
33 changes: 24 additions & 9 deletions pyhealth/datasets/chestxray14.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self,
FileNotFoundError: If the dataset path does not contain the 'images' directory.
ValueError: If the dataset 'images' directory does not contain any PNG files.

Example:
Example::
>>> dataset = ChestXray14Dataset(root="./data")
"""
self._label_path: str = os.path.join(root, "Data_Entry_2017_v2020.csv")
Expand All @@ -98,7 +98,7 @@ def default_task(self) -> ChestXray14MultilabelClassification:
Returns:
ChestXray14MultilabelClassification: The default classification task.

Example:
Example::
>>> dataset = ChestXray14Dataset()
>>> task = dataset.default_task
"""
Expand All @@ -118,12 +118,20 @@ def set_task(self, *args, **kwargs):

return super().set_task(*args, **kwargs)

set_task.__doc__ = (
f"{set_task.__doc__}\n"
" Note:\n"
" If no image processor is provided, a default grayscale `ImageProcessor(mode='L')` is injected. "
"This is needed because the ChestX-ray14 dataset images do not all have the same number of channels, "
"causing the default PyHealth image processor to fail."
)

def _download(self, root: str, partial: bool) -> None:
"""Downloads and verifies the ChestX-ray14 dataset files.

This method performs the following steps:
1. Downloads the label CSV file from a Google Drive mirror.
2. Downloads compressed image archives from NIH Box links.
1. Downloads the label CSV file from the shared NIH Box folder.
2. Downloads compressed image archives from static NIH Box links.
3. Verifies the integrity of each downloaded file using its MD5 checksum.
4. Extracts the image archives to the dataset directory.
5. Removes the original compressed files after successful extraction.
Expand All @@ -138,11 +146,18 @@ def _download(self, root: str, partial: bool) -> None:
ValueError: If an image tar file contains an unsafe path.
ValueError: If an unexpected number of images are downloaded.
"""
# https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468 (mirrored to Google Drive)
# I couldn't figure out a way to download this file directly from box.com
response = requests.get('https://drive.google.com/uc?export=download&id=1mkOZNfYt-Px52b8CJZJANNbM3ULUVO3f')
with open(self._label_path, "wb") as file:
file.write(response.content)
response = requests.get(
url=(
"https://nihcc.app.box.com/index.php"
"?rm=box_download_shared_file"
"&vanity_name=ChestXray-NIHCC"
"&file_id=f_219760887468"
),
allow_redirects=True,
)

with open(self._label_path, "wb") as f:
f.write(response.content)

# https://nihcc.app.box.com/v/ChestXray-NIHCC/file/371647823217
links = [
Expand Down
11 changes: 11 additions & 0 deletions test-resources/core/chestxray14/Data_Entry_2017_v2020.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Sex,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],
00000001_000.png,Cardiomegaly,0,1,57,M,PA,2682,2749,0.14300000000000002,0.14300000000000002,
00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.14300000000000002,0.14300000000000002,
00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,
00000002_000.png,No Finding,0,2,80,M,PA,2500,2048,0.171,0.171,
00000003_001.png,Hernia,0,3,74,F,PA,2500,2048,0.168,0.168,
00000003_002.png,Hernia,1,3,75,F,PA,2048,2500,0.168,0.168,
00000003_003.png,Hernia|Infiltration,2,3,76,F,PA,2698,2991,0.14300000000000002,0.14300000000000002,
00000003_004.png,Hernia,3,3,77,F,PA,2500,2048,0.168,0.168,
00000003_005.png,Hernia,4,3,78,F,PA,2686,2991,0.14300000000000002,0.14300000000000002,
00000003_006.png,Hernia,5,3,79,F,PA,2992,2991,0.14300000000000002,0.14300000000000002,
Empty file.
56 changes: 21 additions & 35 deletions tests/core/test_chestxray14.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
Author:
Eric Schrock ([email protected])
"""
import os
import shutil
from pathlib import Path
import tempfile
import unittest

Expand All @@ -19,38 +18,10 @@
class TestChestXray14Dataset(unittest.TestCase):
@classmethod
def setUpClass(cls):
if os.path.exists("test"):
shutil.rmtree("test")
os.makedirs("test/images")

# Source: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468
lines = [
"Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Sex,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],",
"00000001_000.png,Cardiomegaly,0,1,57,M,PA,2682,2749,0.14300000000000002,0.14300000000000002,",
"00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.14300000000000002,0.14300000000000002,",
"00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,",
"00000002_000.png,No Finding,0,2,80,M,PA,2500,2048,0.171,0.171,",
"00000003_001.png,Hernia,0,3,74,F,PA,2500,2048,0.168,0.168,",
"00000003_002.png,Hernia,1,3,75,F,PA,2048,2500,0.168,0.168,",
"00000003_003.png,Hernia|Infiltration,2,3,76,F,PA,2698,2991,0.14300000000000002,0.14300000000000002,",
"00000003_004.png,Hernia,3,3,77,F,PA,2500,2048,0.168,0.168,",
"00000003_005.png,Hernia,4,3,78,F,PA,2686,2991,0.14300000000000002,0.14300000000000002,",
"00000003_006.png,Hernia,5,3,79,F,PA,2992,2991,0.14300000000000002,0.14300000000000002,",
]

# Create mock images to test image loading
for line in lines[1:]: # Skip header row
name = line.split(',')[0]
img = Image.fromarray(np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8), mode="RGBA")
img.save(os.path.join("test/images", name))

# Save image labels to file
with open("test/Data_Entry_2017_v2020.csv", 'w') as f:
f.write("\n".join(lines))

cls.root = Path(__file__).parent.parent.parent / "test-resources" / "core" / "chestxray14"
cls.generate_fake_images()
cls.cache_dir = tempfile.TemporaryDirectory()

cls.dataset = ChestXray14Dataset(root="./test", cache_dir=cls.cache_dir.name)
cls.dataset = ChestXray14Dataset(cls.root, cache_dir=cls.cache_dir.name)

cls.samples_cardiomegaly = cls.dataset.set_task(ChestXray14BinaryClassification(disease="cardiomegaly"))
cls.samples_hernia = cls.dataset.set_task(ChestXray14BinaryClassification(disease="hernia"))
Expand All @@ -62,8 +33,23 @@ def tearDownClass(cls):
cls.samples_hernia.close()
cls.samples_multilabel.close()

if os.path.exists("test"):
shutil.rmtree("test")
Path(cls.dataset.root / "chestxray14-metadata-pyhealth.csv").unlink()
cls.delete_fake_images()

@classmethod
def generate_fake_images(cls):
    """Create one random placeholder PNG per data row of the label CSV.

    Reads ``Data_Entry_2017_v2020.csv`` under ``cls.root``, skips the header
    row, and for every remaining row writes a 224x224, four-channel ``uint8``
    image of random pixels to ``cls.root / "images" / <first column>``.
    """
    import csv  # Local import: keeps this helper independent of the file's import block.

    images_dir = cls.root / "images"
    # Make sure the target directory exists before any image is saved.
    images_dir.mkdir(parents=True, exist_ok=True)

    with open(cls.root / "Data_Entry_2017_v2020.csv", "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # Skip header row
        for row in reader:
            # Ignore blank lines (e.g. a trailing newline) so no image is
            # written with an empty file name.
            if not row or not row[0]:
                continue
            pixels = np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8)
            Image.fromarray(pixels).save(images_dir / row[0])

@classmethod
def delete_fake_images(cls):
    """Remove every ``*.png`` file directly inside ``cls.root / "images"``."""
    images_dir = Path(cls.root / "images")
    for image_path in images_dir.glob("*.png"):
        image_path.unlink()

def test_stats(self):
self.dataset.stats()
Expand Down