Skip to content

Commit 9783019

Browse files
authored
Tidy up the ChestX-ray14 dataset for the PyHealth 2.0 release (#777)
* Update the chestxray14 unit tests to use the test-resources directory like the other unit tests (fake images are still generated at test time to avoid adding large files to the repo history)
* Fix the Pillow `mode` parameter deprecation warning
* Port the chestxray14 examples to Python scripts
* Update ChestXray14Dataset to download the labels and metadata CSV file directly from the NIH Box share instead of from a mirror in a personal Google Drive
* Add a note to the set_task docstring explaining why a default grayscale image processor is needed for the chestxray14 dataset
1 parent 3d6e318 commit 9783019

File tree

8 files changed

+140
-2736
lines changed

8 files changed

+140
-2736
lines changed

examples/chestxray14_binary_classification.ipynb

Lines changed: 0 additions & 1426 deletions
This file was deleted.

examples/chestxray14_multilabel_classification.ipynb

Lines changed: 0 additions & 1266 deletions
This file was deleted.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Example: binary classification on ChestX-ray14 with a PyHealth CNN.

Downloads the dataset with ``partial=True`` into a temporary directory,
builds the "infiltration" binary classification task, trains a CNN for one
epoch and evaluates it on the held-out test split.
"""
import tempfile

from pyhealth.datasets import ChestXray14Dataset, get_dataloader, split_by_sample
from pyhealth.models import CNN
from pyhealth.tasks import ChestXray14BinaryClassification
from pyhealth.trainer import Trainer


def main() -> None:
    """Run the end-to-end download / train / evaluate example."""
    # Use tempfile context managers so the downloaded data and the cache are
    # removed deterministically, even if training raises part-way through.
    with tempfile.TemporaryDirectory() as dataset_dir, \
            tempfile.TemporaryDirectory() as cache_dir:
        dataset = ChestXray14Dataset(
            root=dataset_dir,
            cache_dir=cache_dir,
            download=True,
            partial=True,
        )
        dataset.stats()

        task = ChestXray14BinaryClassification(disease="infiltration")
        samples = dataset.set_task(task)

        # Close the samples before the temporary directories are deleted,
        # whether or not training/evaluation succeeds.
        try:
            train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.7, 0.1, 0.2])

            train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)
            val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)
            test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)

            model = CNN(dataset=samples)

            trainer = Trainer(model=model)
            trainer.train(
                train_dataloader=train_loader,
                val_dataloader=val_loader,
                epochs=1,
            )

            trainer.evaluate(test_loader)
        finally:
            samples.close()


# Since PyHealth uses multiprocessing, it is best practice to use a main guard.
if __name__ == '__main__':
    main()
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Example: multilabel classification on ChestX-ray14 with a PyHealth CNN.

Downloads the dataset with ``partial=True`` into a temporary directory,
applies the dataset's default task (``set_task()`` with no arguments), trains
a CNN for one epoch and evaluates it on the held-out test split.
"""
import tempfile

from pyhealth.datasets import ChestXray14Dataset, get_dataloader, split_by_sample
from pyhealth.models import CNN
from pyhealth.trainer import Trainer


def main() -> None:
    """Run the end-to-end download / train / evaluate example."""
    # Use tempfile context managers so the downloaded data and the cache are
    # removed deterministically, even if training raises part-way through.
    with tempfile.TemporaryDirectory() as dataset_dir, \
            tempfile.TemporaryDirectory() as cache_dir:
        dataset = ChestXray14Dataset(
            root=dataset_dir,
            cache_dir=cache_dir,
            download=True,
            partial=True,
        )
        dataset.stats()

        samples = dataset.set_task()

        # Close the samples before the temporary directories are deleted,
        # whether or not training/evaluation succeeds.
        try:
            train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.7, 0.1, 0.2])

            train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)
            val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)
            test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)

            model = CNN(dataset=samples)

            # Only measure accuracy: with the "partial" dataset the validation
            # and test sets are unlikely to contain positive samples of every
            # label.
            trainer = Trainer(model=model, metrics=["accuracy"])
            trainer.train(
                train_dataloader=train_loader,
                val_dataloader=val_loader,
                epochs=1,
            )

            trainer.evaluate(test_loader)
        finally:
            samples.close()


# Since PyHealth uses multiprocessing, it is best practice to use a main guard.
if __name__ == '__main__':
    main()

pyhealth/datasets/chestxray14.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(self,
7171
FileNotFoundError: If the dataset path does not contain the 'images' directory.
7272
ValueError: If the dataset 'images' directory does not contain any PNG files.
7373
74-
Example:
74+
Example::
7575
>>> dataset = ChestXray14Dataset(root="./data")
7676
"""
7777
self._label_path: str = os.path.join(root, "Data_Entry_2017_v2020.csv")
@@ -98,7 +98,7 @@ def default_task(self) -> ChestXray14MultilabelClassification:
9898
Returns:
9999
ChestXray14MultilabelClassification: The default classification task.
100100
101-
Example:
101+
Example::
102102
>>> dataset = ChestXray14Dataset()
103103
>>> task = dataset.default_task
104104
"""
@@ -118,12 +118,20 @@ def set_task(self, *args, **kwargs):
118118

119119
return super().set_task(*args, **kwargs)
120120

121+
set_task.__doc__ = (
122+
f"{set_task.__doc__}\n"
123+
" Note:\n"
124+
" If no image processor is provided, a default grayscale `ImageProcessor(mode='L')` is injected. "
125+
"This is needed because the ChestX-ray14 dataset images do not all have the same number of channels, "
126+
"causing the default PyHealth image processor to fail."
127+
)
128+
121129
def _download(self, root: str, partial: bool) -> None:
122130
"""Downloads and verifies the ChestX-ray14 dataset files.
123131
124132
This method performs the following steps:
125-
1. Downloads the label CSV file from a Google Drive mirror.
126-
2. Downloads compressed image archives from NIH Box links.
133+
1. Downloads the label CSV file from the shared NIH Box folder.
134+
2. Downloads compressed image archives from static NIH Box links.
127135
3. Verifies the integrity of each downloaded file using its MD5 checksum.
128136
4. Extracts the image archives to the dataset directory.
129137
5. Removes the original compressed files after successful extraction.
@@ -138,11 +146,18 @@ def _download(self, root: str, partial: bool) -> None:
138146
ValueError: If an image tar file contains an unsafe path.
139147
ValueError: If an unexpected number of images are downloaded.
140148
"""
141-
# https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468 (mirrored to Google Drive)
142-
# I couldn't figure out a way to download this file directly from box.com
143-
response = requests.get('https://drive.google.com/uc?export=download&id=1mkOZNfYt-Px52b8CJZJANNbM3ULUVO3f')
144-
with open(self._label_path, "wb") as file:
145-
file.write(response.content)
149+
response = requests.get(
150+
url=(
151+
"https://nihcc.app.box.com/index.php"
152+
"?rm=box_download_shared_file"
153+
"&vanity_name=ChestXray-NIHCC"
154+
"&file_id=f_219760887468"
155+
),
156+
allow_redirects=True,
157+
)
158+
159+
with open(self._label_path, "wb") as f:
160+
f.write(response.content)
146161

147162
# https://nihcc.app.box.com/v/ChestXray-NIHCC/file/371647823217
148163
links = [
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Sex,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],
2+
00000001_000.png,Cardiomegaly,0,1,57,M,PA,2682,2749,0.14300000000000002,0.14300000000000002,
3+
00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.14300000000000002,0.14300000000000002,
4+
00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,
5+
00000002_000.png,No Finding,0,2,80,M,PA,2500,2048,0.171,0.171,
6+
00000003_001.png,Hernia,0,3,74,F,PA,2500,2048,0.168,0.168,
7+
00000003_002.png,Hernia,1,3,75,F,PA,2048,2500,0.168,0.168,
8+
00000003_003.png,Hernia|Infiltration,2,3,76,F,PA,2698,2991,0.14300000000000002,0.14300000000000002,
9+
00000003_004.png,Hernia,3,3,77,F,PA,2500,2048,0.168,0.168,
10+
00000003_005.png,Hernia,4,3,78,F,PA,2686,2991,0.14300000000000002,0.14300000000000002,
11+
00000003_006.png,Hernia,5,3,79,F,PA,2992,2991,0.14300000000000002,0.14300000000000002,

test-resources/core/chestxray14/images/.gitkeep

Whitespace-only changes.

tests/core/test_chestxray14.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
Author:
55
Eric Schrock (ejs9@illinois.edu)
66
"""
7-
import os
8-
import shutil
7+
from pathlib import Path
98
import tempfile
109
import unittest
1110

@@ -19,38 +18,10 @@
1918
class TestChestXray14Dataset(unittest.TestCase):
2019
@classmethod
2120
def setUpClass(cls):
22-
if os.path.exists("test"):
23-
shutil.rmtree("test")
24-
os.makedirs("test/images")
25-
26-
# Source: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468
27-
lines = [
28-
"Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Sex,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],",
29-
"00000001_000.png,Cardiomegaly,0,1,57,M,PA,2682,2749,0.14300000000000002,0.14300000000000002,",
30-
"00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.14300000000000002,0.14300000000000002,",
31-
"00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,",
32-
"00000002_000.png,No Finding,0,2,80,M,PA,2500,2048,0.171,0.171,",
33-
"00000003_001.png,Hernia,0,3,74,F,PA,2500,2048,0.168,0.168,",
34-
"00000003_002.png,Hernia,1,3,75,F,PA,2048,2500,0.168,0.168,",
35-
"00000003_003.png,Hernia|Infiltration,2,3,76,F,PA,2698,2991,0.14300000000000002,0.14300000000000002,",
36-
"00000003_004.png,Hernia,3,3,77,F,PA,2500,2048,0.168,0.168,",
37-
"00000003_005.png,Hernia,4,3,78,F,PA,2686,2991,0.14300000000000002,0.14300000000000002,",
38-
"00000003_006.png,Hernia,5,3,79,F,PA,2992,2991,0.14300000000000002,0.14300000000000002,",
39-
]
40-
41-
# Create mock images to test image loading
42-
for line in lines[1:]: # Skip header row
43-
name = line.split(',')[0]
44-
img = Image.fromarray(np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8), mode="RGBA")
45-
img.save(os.path.join("test/images", name))
46-
47-
# Save image labels to file
48-
with open("test/Data_Entry_2017_v2020.csv", 'w') as f:
49-
f.write("\n".join(lines))
50-
21+
cls.root = Path(__file__).parent.parent.parent / "test-resources" / "core" / "chestxray14"
22+
cls.generate_fake_images()
5123
cls.cache_dir = tempfile.TemporaryDirectory()
52-
53-
cls.dataset = ChestXray14Dataset(root="./test", cache_dir=cls.cache_dir.name)
24+
cls.dataset = ChestXray14Dataset(cls.root, cache_dir=cls.cache_dir.name)
5425

5526
cls.samples_cardiomegaly = cls.dataset.set_task(ChestXray14BinaryClassification(disease="cardiomegaly"))
5627
cls.samples_hernia = cls.dataset.set_task(ChestXray14BinaryClassification(disease="hernia"))
@@ -62,8 +33,23 @@ def tearDownClass(cls):
6233
cls.samples_hernia.close()
6334
cls.samples_multilabel.close()
6435

65-
if os.path.exists("test"):
66-
shutil.rmtree("test")
36+
Path(cls.dataset.root / "chestxray14-metadata-pyhealth.csv").unlink()
37+
cls.delete_fake_images()
38+
39+
@classmethod
def generate_fake_images(cls):
    """Write one random PNG into ``images/`` for each data row of the label CSV.

    The image file names come from the first comma-separated field of every
    row after the header in ``Data_Entry_2017_v2020.csv`` under ``cls.root``.
    Each image is a 224x224 array of random ``uint8`` values with 4 channels.
    """
    label_csv = Path(cls.root / "Data_Entry_2017_v2020.csv")
    with open(label_csv, 'r') as label_file:
        rows = label_file.readlines()

    # rows[0] is the header; every later row starts with the image file name.
    image_names = [row.split(',')[0] for row in rows[1:]]

    for image_name in image_names:
        pixels = np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8)
        fake_image = Image.fromarray(pixels)
        fake_image.save(Path(cls.root / "images" / image_name))
48+
49+
@classmethod
def delete_fake_images(cls):
    """Delete every ``*.png`` file found directly in ``cls.root / "images"``."""
    image_dir = Path(cls.root / "images")
    stale_pngs = image_dir.glob("*.png")
    for stale_png in stale_pngs:
        stale_png.unlink()
6753

6854
def test_stats(self):
6955
self.dataset.stats()

0 commit comments

Comments
 (0)