54 changes: 54 additions & 0 deletions test/test_update_PCAM.py
@@ -0,0 +1,54 @@
import pathlib
import unittest

import datasets_utils
import numpy as np
from common_utils import combinations_grid
from torchvision import datasets


class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.PCAM

ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
REQUIRED_PACKAGES = ("h5py",)

    def inject_fake_data(self, tmpdir: str, config):
        # Create minimal HDF5 files mirroring the real PCAM layout: images under "x", targets under "y".
base_folder = pathlib.Path(tmpdir) / "pcam"
base_folder.mkdir()

num_images = {"train": 2, "test": 3, "val": 4}[config["split"]]

images_file = datasets.PCAM._FILES[config["split"]]["images"][0]
with datasets_utils.lazy_importer.h5py.File(str(base_folder / images_file), "w") as f:
f["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8)

targets_file = datasets.PCAM._FILES[config["split"]]["targets"][0]
with datasets_utils.lazy_importer.h5py.File(str(base_folder / targets_file), "w") as f:
f["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8)

return num_images

if __name__ == "__main__":
unittest.main()
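
The datasets_utils harness drives this case: for each configured split it calls inject_fake_data and then instantiates the dataset against the generated files. A rough standalone sketch of that flow (illustrative only, not part of the diff; it reuses the same _FILES lookup and fake-data shapes as the test above):

import pathlib
import tempfile

import h5py
import numpy as np
from torchvision import datasets

with tempfile.TemporaryDirectory() as tmpdir:
    base = pathlib.Path(tmpdir) / "pcam"
    base.mkdir()
    # same fake HDF5 layout as inject_fake_data: images under "x", targets under "y"
    with h5py.File(base / datasets.PCAM._FILES["train"]["images"][0], "w") as f:
        f["x"] = np.random.randint(0, 256, size=(2, 10, 10, 3), dtype=np.uint8)
    with h5py.File(base / datasets.PCAM._FILES["train"]["targets"][0], "w") as f:
        f["y"] = np.random.randint(0, 2, size=(2, 1, 1, 1), dtype=np.uint8)

    ds = datasets.PCAM(root=tmpdir, split="train", download=False)
    assert len(ds) == 2
    image, target = ds[0]  # PIL.Image.Image and an int label in {0, 1}
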
89 changes: 89 additions & 0 deletions test/test_update_PCAM_multiprocessing.py
@@ -0,0 +1,89 @@
# test/test_update_PCAM_multiprocessing.py
import os
import socket
import tempfile
from contextlib import closing

import pytest
import torch
from torch.utils.data import DataLoader, distributed
from torchvision import datasets
from torchvision.transforms import v2


def _find_free_port() -> int:
"""Pick a free TCP port for DDP init."""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
return s.getsockname()[1]


def _ddp_worker(rank: int, world_size: int, port: int, root: str, backend: str):
"""Single DDP worker."""
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)

torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)

if backend == "nccl":
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
pin_memory = True
else:
device = torch.device("cpu")
pin_memory = False

# ---- dataset ----
    transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])  # v2.ToTensor is deprecated
    ds = datasets.PCAM(root=root, split="train", download=True, transform=transform)
sampler = distributed.DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(
ds,
batch_size=16,
sampler=sampler,
num_workers=2,
pin_memory=pin_memory,
persistent_workers=True,
)

    # ---- iterate over a few batches ----
local_seen = 0
for i, (x, y) in enumerate(loader):
assert x.ndim == 4 and y.ndim == 1
if backend == "nccl":
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
local_seen += x.size(0)
if i >= 3:
break

# ---- allreduce sanity check ----
t = torch.tensor([local_seen], dtype=torch.int64, device=device)
torch.distributed.all_reduce(t)
assert t.item() > 0

torch.distributed.destroy_process_group()


@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("backend", ["gloo"]) # add "nccl" if you want GPU test too
def test_pcam_ddp(world_size, backend):
"""Smoke test PCAM with DDP + multiprocessing DataLoader."""
if backend == "nccl" and not torch.cuda.is_available():
pytest.skip("CUDA not available for NCCL backend")

with tempfile.TemporaryDirectory() as tmp:
root = os.path.join(tmp, "pcam_data")
os.makedirs(root, exist_ok=True)
port = _find_free_port()

        # Spawn one worker process per rank; each runs _ddp_worker(rank, *args)
torch.multiprocessing.spawn(
_ddp_worker,
args=(world_size, port, root, backend), # passed to worker after rank
nprocs=world_size,
join=True,
)


if __name__ == "__main__":
# Allow running standalone: python test/test_update_PCAM_multiprocessing.py
pytest.main([__file__, "-vvv", "-k", "test_pcam_ddp"])
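
The gloo-only parametrization keeps the test runnable on CPU-only CI; as the inline comment notes, a GPU run only needs the backend list widened. A sketch of that variant (an assumption, not part of the diff; it relies on the existing CUDA skip and needs at least world_size GPUs):

# Sketch: GPU-capable decorators for test_pcam_ddp; the body stays as written above,
# and the torch.cuda.is_available() check skips "nccl" on CPU-only machines.
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
def test_pcam_ddp(world_size, backend):
    ...
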
27 changes: 15 additions & 12 deletions torchvision/datasets/pcam.py
@@ -6,7 +6,16 @@
from .utils import _decompress, download_file_from_google_drive, verify_str_arg
from .vision import VisionDataset


def _get_h5py():
"""Import h5py on demand with a clear error message."""
try:
import h5py
return h5py
except ImportError:
raise RuntimeError(
"h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
)

class PCAM(VisionDataset):
"""`PCAM Dataset <https://github.com/basveeling/pcam>`_.

@@ -78,14 +87,6 @@ def __init__(
target_transform: Optional[Callable] = None,
download: bool = False,
):
try:
import h5py

self.h5py = h5py
except ImportError:
raise RuntimeError(
"h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
)

self._split = verify_str_arg(split, "split", ("train", "test", "val"))

@@ -100,16 +101,18 @@ def __init__(

def __len__(self) -> int:
images_file = self._FILES[self._split]["images"][0]
with self.h5py.File(self._base_folder / images_file) as images_data:
h5py = _get_h5py()
with h5py.File(self._base_folder / images_file) as images_data:
return images_data["x"].shape[0]

def __getitem__(self, idx: int) -> tuple[Any, Any]:
images_file = self._FILES[self._split]["images"][0]
with self.h5py.File(self._base_folder / images_file) as images_data:
h5py = _get_h5py()
with h5py.File(self._base_folder / images_file) as images_data:
image = Image.fromarray(images_data["x"][idx]).convert("RGB")

targets_file = self._FILES[self._split]["targets"][0]
with self.h5py.File(self._base_folder / targets_file) as targets_data:
with h5py.File(self._base_folder / targets_file) as targets_data:
target = int(targets_data["y"][idx, 0, 0, 0]) # shape is [num_images, 1, 1, 1]

if self.transform:
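
Module objects cannot be pickled, so the removed self.h5py attribute kept PCAM instances from pickling, which matters for spawn-based DataLoader workers and is presumably what the multiprocessing test above exercises. With the on-demand import the instance carries no module reference; a minimal check (illustrative, assumes the PCAM train files already exist under data/pcam):

import pickle

from torchvision import datasets

ds = datasets.PCAM(root="data", split="train", download=False)
pickle.dumps(ds)  # with the old self.h5py attribute this would raise "cannot pickle 'module' object"
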