Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
images/
*.egg-info
.vscode
.env
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ This module exposes a single function `download` which takes the same arguments
* **parquet** loads the urls and optional caption as a parquet
* **url_col** the name of the url column for parquet and csv (default *url*)
* **caption_col** the name of the caption column for parquet and csv (default *None*)
* **bbox_col** the name of the bounding box column. Bounding boxes are assumed to have format ```[x_min, y_min, x_max, y_max]```, with all elements being floats in *[0,1]* (relative to the size of the image). If *None*, then no bounding box blurring is performed (default *None*)
* **bbox_col** the name of the bounding box column. Bounding boxes are assumed to have format ```[x_min, y_min, x_max, y_max]```, with all elements being floats in *[0,1]* (relative to the size of the image). If *None*, then no bounding box operation is performed (default *None*)
* **bbox_operation** the operation to perform with the bounding box metadata; only used when bbox_col is specified (default *blur*)
* **blur** blurs the regions of the image covered by the bounding boxes
* **crop** crops the image to the bounding box
* **number_sample_per_shard** the number of samples that will be downloaded in one shard (default *10000*)
* **extract_exif** if true, extract the exif information of the images and save it to the metadata (default *True*)
* **save_additional_columns** list of additional columns to take from the csv/parquet files and save in metadata files (default *None*)
Expand Down
53 changes: 48 additions & 5 deletions img2dataset/blurrer.py → img2dataset/bbox_processors.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,42 @@
"""blurrer module to blur parts of the image"""
"""Bounding box processing"""
from abc import ABC, abstractmethod
from enum import Enum

import numpy as np

import albumentations as A


class BoundingBoxBlurrer:
class BBoxOperation(Enum):
    """Operations that a bounding box processor can apply to an image."""

    # Member names are lowercase on purpose, hence the pylint pragmas.
    blur = 0  # pylint: disable=invalid-name
    crop = 1  # pylint: disable=invalid-name


class BBoxProcessor(ABC):
    """
    Abstract class for Bounding box processing

    A processor is a callable that takes an image together with bounding box
    data and returns the processed image. Each processor records which
    operation it implements so that callers can decide when to invoke it.
    """

    def __init__(self, operation):
        """Store the operation implemented by this processor.

        Args:
            operation: the operation tag reported by ``get_operation``.
        """
        self.operation = operation

    @abstractmethod
    def __call__(self, img, bbox_list):
        """Apply the operation to ``img`` using ``bbox_list``.

        The signature matches the one used by the concrete processors, so that
        overrides do not differ from the abstract declaration.

        Args:
            img: the image to process.
            bbox_list: the bounding box data driving the operation.

        Returns:
            The processed image.
        """

    def get_operation(self) -> "BBoxOperation":
        """Return the operation that was passed to the constructor."""
        return self.operation


class BlurProcessor(BBoxProcessor):
"""blur images based on a bounding box.

The bounding box used is assumed to have format [x_min, y_min, x_max, y_max]
(with elements being floats in [0,1], relative to the original shape of the
image).
"""

def __init__(self) -> None:
pass
def __init__(self, *args, **kwargs):
super().__init__(BBoxOperation.blur, *args, **kwargs)

def __call__(self, img, bbox_list):
"""Apply blurring to bboxes of an image.
Expand Down Expand Up @@ -78,3 +100,24 @@ def __call__(self, img, bbox_list):
result = (result * 255.0).astype(np.uint8)

return result


class CropProcessor(BBoxProcessor):
    """Crop images based on a bounding box.

    The bounding box used is assumed to have format [x_min, y_min, x_max, y_max]
    (with elements being floats in [0,1], relative to the original shape of the
    image).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(BBoxOperation.crop, *args, **kwargs)

    def __call__(self, img, bbox_list):
        """Crop an image to a bounding box.

        Args:
            img: image array whose first two axes are (height, width).
            bbox_list: a single box ``[x_min, y_min, x_max, y_max]`` in
                relative coordinates, or None when the sample has no box.

        Returns:
            The region of ``img`` selected by the box, or ``img`` unchanged
            when ``bbox_list`` is None.
        """
        # Nothing to crop with: hand the image back instead of failing on
        # ``None[0]``.
        if bbox_list is None:
            return img

        height, width = img.shape[0], img.shape[1]
        x_min = self._to_pixel(bbox_list[0], width)
        y_min = self._to_pixel(bbox_list[1], height)
        x_max = self._to_pixel(bbox_list[2], width)
        y_max = self._to_pixel(bbox_list[3], height)
        return img[y_min:y_max, x_min:x_max]

    @staticmethod
    def _to_pixel(relative, size):
        """Scale a relative coordinate to pixels, clamped to [0, size].

        Clamping matters for the lower bound: a negative slice index would be
        interpreted as counting from the end of the axis and silently select
        the wrong region.
        """
        return min(max(int(relative * size), 0), size)
6 changes: 3 additions & 3 deletions img2dataset/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
retries,
user_agent_token,
disallowed_header_directives,
blurring_bbox_col=None,
bbox_col=None,
) -> None:
self.sample_writer_class = sample_writer_class
self.resizer = resizer
Expand All @@ -118,7 +118,7 @@ def __init__(
if disallowed_header_directives is None
else {directive.strip().lower() for directive in disallowed_header_directives}
)
self.blurring_bbox_col = blurring_bbox_col
self.bbox_col = bbox_col

def __call__(
self,
Expand Down Expand Up @@ -176,7 +176,7 @@ def download_shard(
hash_indice = (
self.column_list.index(self.verify_hash_type) if self.verify_hash_type in self.column_list else None
)
bbox_indice = self.column_list.index(self.blurring_bbox_col) if self.blurring_bbox_col is not None else None
bbox_indice = self.column_list.index(self.bbox_col) if self.bbox_col is not None else None
key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl]

# this prevents an accumulation of more than twice the number of threads in sample ready to resize
Expand Down
18 changes: 13 additions & 5 deletions img2dataset/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from .logger import LoggerProcess
from .resizer import Resizer
from .blurrer import BoundingBoxBlurrer
from .bbox_processors import BBoxProcessor, BlurProcessor, CropProcessor
from .writer import (
WebDatasetSampleWriter,
FilesSampleWriter,
Expand Down Expand Up @@ -83,6 +83,7 @@ def download(
url_col: str = "url",
caption_col: Optional[str] = None,
bbox_col: Optional[str] = None,
bbox_operation: Optional[str] = "blur",
thread_count: int = 256,
number_sample_per_shard: int = 10000,
extract_exif: bool = True,
Expand Down Expand Up @@ -198,10 +199,17 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
else:
raise ValueError(f"Invalid output format {output_format}")

bbox_processor: Optional[BBoxProcessor]
if bbox_col is not None:
blurrer = BoundingBoxBlurrer()
if bbox_operation is None:
bbox_operation = "blur"

if bbox_operation.lower() == "blur":
bbox_processor = BlurProcessor()
elif bbox_operation.lower() == "crop":
bbox_processor = CropProcessor()
else:
blurrer = None
bbox_processor = None

resizer = Resizer(
image_size=image_size,
Expand All @@ -216,7 +224,7 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
min_image_size=min_image_size,
max_image_area=max_image_area,
max_aspect_ratio=max_aspect_ratio,
blurrer=blurrer,
bbox_processor=bbox_processor,
)

downloader = Downloader(
Expand All @@ -236,7 +244,7 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
retries=retries,
user_agent_token=user_agent_token,
disallowed_header_directives=disallowed_header_directives,
blurring_bbox_col=bbox_col,
bbox_col=bbox_col,
)

print("Starting the downloading of this file")
Expand Down
36 changes: 25 additions & 11 deletions img2dataset/resizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
import imghdr
import os
from .bbox_processors import BBoxOperation

_INTER_STR_TO_CV2 = dict(
nearest=cv2.INTER_NEAREST,
Expand Down Expand Up @@ -93,7 +94,7 @@ def __init__(
min_image_size=0,
max_image_area=float("inf"),
max_aspect_ratio=float("inf"),
blurrer=None,
bbox_processor=None,
):
if encode_format not in ["jpg", "png", "webp"]:
raise ValueError(f"Invalid encode format {encode_format}")
Expand Down Expand Up @@ -132,9 +133,9 @@ def __init__(
self.min_image_size = min_image_size
self.max_image_area = max_image_area
self.max_aspect_ratio = max_aspect_ratio
self.blurrer = blurrer
self.bbox_processor = bbox_processor

def __call__(self, img_stream, blurring_bbox_list=None):
def __call__(self, img_stream, bbox_list=None):
"""
input: an image stream, optionally the bounding box data used to blur or crop the image.
output: img_str, width, height, original_width, original_height, err
Expand Down Expand Up @@ -168,20 +169,25 @@ def __call__(self, img_stream, blurring_bbox_list=None):
return None, None, None, None, None, "aspect ratio too large"

# check if the bbox processor was defined during init if needed
if blurring_bbox_list is not None and self.blurrer is None:
return None, None, None, None, None, "blurrer not defined"
if bbox_list is not None and self.bbox_processor is None:
return None, None, None, None, None, "bbox processor is not defined"

# Flag to check if blurring is still needed.
maybe_blur_still_needed = True
if self.bbox_processor is not None and self.bbox_processor.get_operation() is BBoxOperation.crop:
img = self.bbox_processor(img=img, bbox_list=bbox_list)

# resizing in following conditions
if self.resize_mode in (ResizeMode.keep_ratio, ResizeMode.center_crop):
downscale = min(original_width, original_height) > self.image_size
if not self.resize_only_if_bigger or downscale:
interpolation = self.downscale_interpolation if downscale else self.upscale_interpolation
img = A.smallest_max_size(img, self.image_size, interpolation=interpolation)
if blurring_bbox_list is not None and self.blurrer is not None:
img = self.blurrer(img=img, bbox_list=blurring_bbox_list)
if (
self.bbox_processor is not None
and self.bbox_processor.get_operation() is BBoxOperation.blur
):
img = self.bbox_processor(img=img, bbox_list=bbox_list)
if self.resize_mode == ResizeMode.center_crop:
img = A.center_crop(img, self.image_size, self.image_size)
encode_needed = True
Expand All @@ -191,8 +197,11 @@ def __call__(self, img_stream, blurring_bbox_list=None):
if not self.resize_only_if_bigger or downscale:
interpolation = self.downscale_interpolation if downscale else self.upscale_interpolation
img = A.longest_max_size(img, self.image_size, interpolation=interpolation)
if blurring_bbox_list is not None and self.blurrer is not None:
img = self.blurrer(img=img, bbox_list=blurring_bbox_list)
if (
self.bbox_processor is not None
and self.bbox_processor.get_operation() is BBoxOperation.blur
):
img = self.bbox_processor(img=img, bbox_list=bbox_list)
if self.resize_mode == ResizeMode.border:
img = A.pad(
img,
Expand All @@ -205,8 +214,13 @@ def __call__(self, img_stream, blurring_bbox_list=None):
maybe_blur_still_needed = False

# blur parts of the image if needed
if maybe_blur_still_needed and blurring_bbox_list is not None and self.blurrer is not None:
img = self.blurrer(img=img, bbox_list=blurring_bbox_list)
if (
maybe_blur_still_needed
and bbox_list is not None
and self.bbox_processor is not None
and self.bbox_processor.get_operation() is BBoxOperation.blur
):
img = self.bbox_processor(img=img, bbox_list=bbox_list)

height, width = img.shape[:2]
if encode_needed:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_blurrer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for the bounding box blurring module."""

from img2dataset.blurrer import BoundingBoxBlurrer
from img2dataset.bbox_processors import BlurProcessor
import os
import pytest
import cv2
Expand All @@ -15,7 +15,7 @@ def test_blurrer():
blur_image_path = os.path.join(test_folder, "blurred.png")
bbox_path = os.path.join(test_folder, "bbox.npy")

blurrer = BoundingBoxBlurrer()
blurrer = BlurProcessor()
orig_image = cv2.imread(orig_image_path)
blur_image = cv2.imread(blur_image_path)
with open(bbox_path, "rb") as f:
Expand Down