def print_section(title, width=60):
    """Print a section banner: a blank line, a rule, the title, a closing rule.

    Args:
        title: Heading text printed between the two rules.
        width: Number of ``=`` characters in each rule. Defaults to 60, the
            width previously hard-coded here.
    """
    rule = "=" * width
    print("\n" + rule)
    print(f" {title}")
    print(rule)
def example_2_file_based_injection():
    """Example 2: Inject tokens drawn from files and from an in-memory list.

    Runs three demonstrations (countries file, colors file, custom word list).
    Each one builds a modifier and a matching highlighter, applies the
    modifier to a sample sentence, and prints the highlighted result.
    """
    print_section("Example 2: File-Based Token Injection")

    def show(heading, modifier, highlighter, text):
        # The modifier is called with the text and label 1; the label it
        # returns is not needed for display and is discarded.
        modified_text, _ = modifier(text, 1)
        print(heading)
        pretty_print(modified_text, highlighter)
        print("-" * 50)

    # Example 2a: Country injection
    show(
        "2a. Country injection:",
        ItemInjection.from_file(
            file_path="examples/data/countries.txt",
            location="random",
            token_proportion=0.3,
            seed=42,
        ),
        highlight_from_file("examples/data/countries.txt"),
        "International trade agreements benefit global economic stability.",
    )

    # Example 2b: Color injection
    show(
        "2b. Color injection:",
        ItemInjection.from_file(
            file_path="examples/data/colors.txt",
            location="random",
            token_proportion=1,
            seed=42,
        ),
        highlight_from_file("examples/data/colors.txt"),
        "The sunset painted the sky beautifully.",
    )

    # Example 2c: Custom word list
    show(
        "2c. Custom urgent words:",
        ItemInjection.from_list(
            items=["URGENT", "BREAKING", "EXCLUSIVE", "ALERT"],
            location="random",
            token_proportion=1,
            seed=42,
        ),
        highlight_from_list(["URGENT", "BREAKING", "EXCLUSIVE", "ALERT"]),
        "Weather forecast predicts rain tomorrow.",
    )
def example_4_multiple_injections():
    """Example 4: Chain three different injections with a CompositeModifier.

    A date modifier (location "beginning"), a country modifier (location
    "random") and a color modifier (location "end") are combined, applied to
    one sentence, and the resulting text is printed without highlighting.
    """
    print_section("Example 4: Multiple Injection Types Combined")

    # The modifiers are constructed in the same order they are composed:
    # date first, then country, then color.
    multi_modifier = CompositeModifier(
        [
            # Date at beginning
            ItemInjection.from_function(
                SpuriousDateGenerator(year_range=(2020, 2024), seed=42),
                location="beginning",
                token_proportion=0,
            ),
            # Country in middle
            ItemInjection.from_file(
                file_path="examples/data/countries.txt",
                location="random",
                token_proportion=0,
                seed=43,
            ),
            # Color at end
            ItemInjection.from_file(
                file_path="examples/data/colors.txt",
                location="end",
                token_proportion=0,
                seed=44,
            ),
        ]
    )

    sentence = "Economic analysis shows promising trends in renewable energy sectors."
    modified_text, _ = multi_modifier(sentence, 1)

    print("4. Multiple injection types:")
    print(modified_text)
    print("-" * 50)
def main():
    """Run every example in this script, in numerical order."""
    examples = (
        example_1_basic_date_injection,
        example_2_file_based_injection,
        example_3_html_injection,
        example_4_multiple_injections,
        example_5_token_density_comparison,
        example_6_dataset_simulation,
    )
    for run_example in examples:
        run_example()
class TextHighlight:
    """Highlight known substrings in text and pretty-print dataset examples.

    The ``mode`` chosen at construction decides which substrings are
    highlighted:

    * ``"dates"`` - every ``YYYY-MM-DD`` date found in the text.
    * ``"file"`` - every non-empty line of ``file_path`` that occurs in the text.
    * ``"html"`` - every whitespace-separated tag listed in ``file_path`` that
      occurs in the text.
    """

    def __init__(self, mode="dates", file_path=None):
        """Select the highlight strategy.

        Args:
            mode (str): One of "dates", "file" or "html".
            file_path (str, optional): File with highlight patterns; required
                when ``mode`` is "file" or "html". It is read once, here.

        Raises:
            ValueError: If ``mode`` is "file"/"html" without a ``file_path``,
                or if ``mode`` is not one of the supported values.
        """
        self.file_path = file_path
        self.mode = mode

        if self.mode in ("html", "file") and not self.file_path:
            raise ValueError(f"file_path must be provided when mode='{self.mode}'")

        if self.mode == "dates":
            self.highlight_func = self._highlight_dates
        elif self.mode == "html":
            self.highlight_func = self._highlight_html()
        elif self.mode == "file":
            self.highlight_func = self._highlight_from_file()
        else:
            raise ValueError(
                f"Unknown highlight mode: {self.mode}, should be in 'dates', 'file', 'html'"
            )

    def pretty_print(self, text: str):
        """Print ``text`` with its matches in green, then a 40-dash rule.

        Args:
            text (str): The text to print.
        """
        if self.highlight_func:
            text = self._colorize(text, self.highlight_func(text))
        print(text)
        print("-" * 40)

    @staticmethod
    def _colorize(text, matches, paint=None):
        """Return ``text`` with every occurrence of each match painted once.

        Args:
            text (str): Text to decorate.
            matches (iterable of str): Substrings to highlight; duplicates and
                empty strings are ignored.
            paint (callable, optional): Maps a matched substring to its
                decorated form. Defaults to green via ``termcolor.colored``.

        Returns:
            str: The decorated text, or ``text`` unchanged when nothing matches.
        """
        # Longest first so a pattern that contains another pattern wins.
        unique = sorted({m for m in matches if m}, key=len, reverse=True)
        if not unique:
            return text
        if paint is None:

            def paint(fragment):
                return colored(fragment, "green")

        # One pass over the ORIGINAL text: chained str.replace calls would
        # re-scan text that already holds escape codes and wrap it again.
        pattern = re.compile("|".join(re.escape(m) for m in unique))
        return pattern.sub(lambda hit: paint(hit.group(0)), text)

    def pretty_print_dataset(self, dataset, n=5, label=None):
        """Print up to ``n`` examples of the dataset with highlighting.

        Args:
            dataset: Iterable of examples indexable by "text" and "labels".
            n (int): Maximum number of examples to print (default is 5).
            label (int, optional): If provided, only examples whose "labels"
                value equals it are printed (and counted towards ``n``).
        """
        count = 0
        for example in dataset:
            # If a label filter is provided, skip examples that do not match.
            if label is not None and example["labels"] != label:
                continue

            print(f"Text {count + 1} (Label={example['labels']}):")
            self.pretty_print(example["text"])
            count += 1
            if count >= n:
                break

    def _highlight_dates(self, text):
        """Find all date patterns in the text in the format YYYY-MM-DD.

        Args:
            text (str): The text to search.

        Returns:
            list: The date strings found, in order of appearance.
        """
        return re.findall(r"\d{4}-\d{2}-\d{2}", text)

    def _read_pattern_lines(self):
        """Return the stripped, non-empty lines of ``self.file_path``."""
        with open(self.file_path, "r", encoding="utf-8") as file:
            return [line.strip() for line in file if line.strip()]

    def _highlight_from_file(self):
        """Build a matcher for the patterns listed one per line in the file.

        Returns:
            callable: Takes text and returns the patterns (whole lines) that
            occur in it, in file order.
        """
        patterns = self._read_pattern_lines()

        def highlight_func(text):
            return [pattern for pattern in patterns if pattern in text]

        return highlight_func

    def _highlight_html(self):
        """Build a matcher for the HTML tags listed in the file.

        Each line may hold several whitespace-separated tags (for example an
        opening and a closing tag); every tag is matched on its own.

        Returns:
            callable: Takes text and returns the tags that occur in it, in
            file order.
        """
        tags = []
        for line in self._read_pattern_lines():
            tags.extend(line.split())

        def highlight_func(text):
            return [tag for tag in tags if tag in text]

        return highlight_func
def _apply_modifier_to_subset(self, dataset, label_to_modify, modifier, proportion):
    """Apply a modifier to a random share of the samples with a given label.

    The dataset is split into the rows whose "labels" equal
    ``label_to_modify`` and all remaining rows. A subset of the matching rows,
    chosen with ``self.rng``, is passed through ``modifier``; the result is
    the (partly modified) matching rows followed by the remaining rows, so the
    original row order is not preserved.

    Args:
        dataset: Dataset exposing ``filter``/``map`` whose rows have "text"
            and "labels" entries.
        label_to_modify: Label value selecting the rows eligible for change.
        modifier: Callable ``(text, label) -> (new_text, new_label)``.
        proportion (float): Share of eligible rows to modify, in [0, 1].

    Returns:
        The concatenation of the processed matching rows and the other rows.

    Raises:
        ValueError: If ``proportion`` is outside [0, 1] (NaN included).
    """
    # Fail with a clear message instead of an opaque random.sample() error.
    if not 0 <= proportion <= 1:
        raise ValueError(f"proportion must be in [0, 1], got {proportion!r}")

    dataset_to_modify = dataset.filter(lambda ex: ex["labels"] == label_to_modify)
    remaining_dataset = dataset.filter(lambda ex: ex["labels"] != label_to_modify)

    n_examples = len(dataset_to_modify)
    n_to_modify = round(n_examples * proportion)
    # Sampling from a range picks the same indices as sampling from the
    # equivalent list, without building that list.
    selected_indices = set(self.rng.sample(range(n_examples), n_to_modify))

    def modify_example(example, idx):
        if idx in selected_indices:
            new_text, new_label = modifier(example["text"], example["labels"])
            example["text"] = new_text
            example["labels"] = new_label
        return example

    modified_subset = dataset_to_modify.map(modify_example, with_indices=True)
    return concatenate_datasets([modified_subset, remaining_dataset])
class ClassConditionalInjector(Transform):
    """Apply a wrapped transform only to samples whose label is targeted.

    When ``total_samples`` is given, a fixed random set of
    ``int(total_samples * proportion)`` dataset indices is drawn once from
    ``seed``; a targeted sample is transformed exactly when its "idx" value
    lies in that set. Without ``total_samples`` each targeted sample is
    transformed with probability ``proportion`` using the global ``random``
    module, and no "idx" entry is needed.

    Args:
        transformation (Transform): Transform applied to the selected samples.
        label_key (str): Key for label in the sample dict.
        target_labels (Union[int, list[int]]): Which labels to modify.
        proportion (float): Fraction of samples to modify (0-1).
        total_samples (int, optional): Dataset size (for deterministic mask).
        seed (int): Seed used to draw the deterministic mask.
    """

    def __init__(
        self,
        transformation: Transform,
        label_key: str = "label",
        target_labels: Union[int, list[int]] = 0,
        proportion: float = 0.5,
        total_samples: Optional[int] = None,
        seed: int = 42,
    ):
        super().__init__()
        self.transformation = transformation
        self.label_key = label_key
        self.target_labels = (
            [target_labels] if isinstance(target_labels, int) else target_labels
        )
        self.proportion = proportion
        self.total_samples = total_samples
        self.seed = seed

        # Precompute deterministic mask if dataset size known
        if total_samples is not None:
            num_to_transform = int(total_samples * proportion)
            rng = torch.Generator().manual_seed(seed)
            self.indices_to_transform = set(
                torch.randperm(total_samples, generator=rng)[:num_to_transform].tolist()
            )
        else:
            self.indices_to_transform = None

    def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return ``x``, passed through the wrapped transform if selected."""
        label = self.nested_get(x, self.label_key)
        if label not in self.target_labels:
            return x

        if self.indices_to_transform is not None:
            # "idx" is only looked up here, where the mask actually needs it.
            # int() lets a 0-dim tensor index match the plain ints in the set.
            idx = int(self.nested_get(x, "idx"))
            should_transform = idx in self.indices_to_transform
        else:
            should_transform = random.random() < self.proportion

        if should_transform:
            x = self.transformation(x)

        return x
def __init__(
    self,
    text_key: str,
    file_path: str,
    location: str = "random",
    token_proportion: float = 0.1,
    seed: Optional[int] = None,
):
    """Load the spurious tokens and store the injection settings.

    Args:
        text_key (str): Key of the text field in each sample dict.
        file_path (str): UTF-8 file with one spurious token per line; blank
            lines are skipped.
        location (str): Where tokens go: "beginning", "random" or "end".
        token_proportion (float): Fraction, in [0, 1], of a sample's word
            count to inject; at least one token is always injected.
        seed (int, optional): Base seed for reproducible injection; ``None``
            gives unseeded randomness.

    Raises:
        AssertionError: If the file has no usable lines, or if
            ``token_proportion`` or ``location`` is invalid.
    """
    # The sibling transforms all initialise the base class; without this the
    # inherited module machinery (repr, .to(), ...) has no state to work on.
    super().__init__()
    self.text_key = text_key
    self.location = location
    self.token_proportion = token_proportion
    self.base_seed = seed
    # store RNG per idx
    self.rngs = {}

    with open(file_path, "r", encoding="utf-8") as f:
        self.items = [line.strip() for line in f if line.strip()]

    assert self.items, f"No valid lines found in {file_path}"
    assert 0 <= self.token_proportion <= 1, "token_proportion must be in [0, 1]"
    assert self.location in {"beginning", "random", "end"}
def __call__(self, x: dict) -> dict:
    """Inject spurious tokens into ``x[self.text_key]`` and return ``x``.

    With a base seed, the random stream is rebuilt on every call from
    ``base_seed + idx`` (``idx`` taken from ``x["idx"]``, 0 if absent), so the
    same sample text always receives the same injection however often or in
    which order samples are visited. Without a seed a fresh unseeded
    generator is used.

    Args:
        x (dict): Sample holding the text under ``self.text_key``.

    Returns:
        dict: The same dict object with its text replaced.
    """
    text = x[self.text_key]

    if self.base_seed is not None:
        # A fresh generator per call keeps memory flat (no per-idx cache that
        # grows with the dataset) and makes the result independent of how
        # many times this idx was processed before.
        rng = random.Random(self.base_seed + int(x.get("idx", 0)))
    else:
        rng = random.Random()

    x[self.text_key] = self._inject(text, rng)
    return x
def _inject_with_tags(self, tokens, opening, closing, location, rng):
    """Insert one opening tag (and its closing tag, if any) into a token list.

    Args:
        tokens (list): Tokens of the text; left unmodified.
        opening (str): Opening tag to insert.
        closing (str or None): Closing tag; skipped when falsy.
        location (str): "beginning" puts the opening tag first, "end" puts the
            closing tag last, "random" places both at random positions. In
            every case the closing tag lands after the opening tag. Any other
            value yields an unchanged copy.
        rng (random.Random): Source of the random positions.

    Returns:
        list: A new token list containing the inserted tag(s).
    """
    result = list(tokens)

    if location == "beginning":
        result.insert(0, opening)
        if closing:
            # Anywhere after the opening tag, up to the very end.
            result.insert(rng.randint(1, len(result)), closing)
        return result

    if location == "end":
        result.insert(rng.randint(0, len(result)), opening)
        if closing:
            result.append(closing)
        return result

    if location == "random":
        open_at = rng.randint(0, len(result))
        result.insert(open_at, opening)
        if closing:
            result.insert(rng.randint(open_at + 1, len(result)), closing)

    return result
class AddPatch(Transform):
    """Add a solid color patch to an image at a fixed position.

    The image is read from the sample's "image" entry and unpacked as
    (channels, height, width). A modified copy is written back; the input
    tensor itself is not changed.

    Args:
        patch_size (float): Fraction of image width/height for the patch (0 < patch_size ≤ 1).
        color (Tuple[float, float, float]): RGB values in [0, 1].
        position (str): Where to place the patch: 'top_left_corner', 'top_right_corner',
            'bottom_left_corner', 'bottom_right_corner', 'center'.

    Raises:
        ValueError: If ``patch_size``, ``color`` or ``position`` is invalid.
    """

    _POSITIONS = (
        "top_left_corner",
        "top_right_corner",
        "bottom_left_corner",
        "bottom_right_corner",
        "center",
    )

    def __init__(
        self,
        patch_size: float = 0.1,
        color: Tuple[float, float, float] = (1.0, 0.0, 0.0),
        position: str = "bottom_right_corner",
    ):
        super().__init__()

        # checking constraints
        if patch_size <= 0 or patch_size > 1:
            raise ValueError("patch_size must be between 0 and 1.")

        if len(color) != 3:
            # Adjacent literals instead of a backslash continuation, which
            # leaked the source indentation into the message.
            raise ValueError(
                "color must be a tuple of size 3 in the form "
                "Tuple[float, float, float] with each representing RGB values in [0, 1]"
            )

        for value in color:
            if value > 1 or value < 0:
                raise ValueError("Each color value must be in [0, 1]")

        if position not in self._POSITIONS:
            raise ValueError(self._invalid_position_message(position))

        self.patch_size = patch_size
        self.color = color
        self.position = position

    @classmethod
    def _invalid_position_message(cls, position):
        """Build the error text listing the accepted positions."""
        return (
            f"Invalid position: {position}, valid positions are: "
            + ", ".join(cls._POSITIONS)
        )

    def _anchor(self, H, W, patch_h, patch_w):
        """Return the (top, left) pixel of the patch for ``self.position``."""
        if self.position == "top_left_corner":
            return 0, 0
        if self.position == "top_right_corner":
            return 0, W - patch_w
        if self.position == "bottom_left_corner":
            return H - patch_h, 0
        if self.position == "bottom_right_corner":
            return H - patch_h, W - patch_w
        if self.position == "center":
            return H // 2 - patch_h // 2, W // 2 - patch_w // 2
        # Reached only if the attribute was changed after construction.
        raise ValueError(self._invalid_position_message(self.position))

    def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Paint the patch onto a copy of ``x["image"]`` and return ``x``."""
        img = self.nested_get(x, "image")
        _, H, W = img.shape

        # At least one pixel: a zero-sized patch would make the negative
        # slices below (e.g. ``-0:``) select the whole image.
        patch_h = max(1, int(H * self.patch_size))
        patch_w = max(1, int(W * self.patch_size))

        top, left = self._anchor(H, W, patch_h, patch_w)
        color = torch.tensor(self.color, dtype=img.dtype, device=img.device)

        img = img.clone()
        # Explicit start/stop bounds give a region of exactly patch_h x patch_w,
        # also for odd sizes where a symmetric "centre +/- half" slice is short.
        img[:, top : top + patch_h, left : left + patch_w] = color.view(3, 1, 1)

        self.nested_set(x, img, "image")
        return x
class AddBorder(Transform):
    """Adds a border around an image.

    The image is read from the sample's "image" entry and unpacked as
    (channels, height, width); a modified copy is written back.

    Args:
        thickness (float): Border width as a fraction of the shorter image
            side. Any positive value draws at least a one-pixel border; 0
            draws nothing.
        color (Tuple[float, float, float]): RGB representation of the color of the border

    Raises:
        ValueError: If ``thickness`` is negative.
    """

    def __init__(
        self, thickness: float = 0.05, color: Tuple[float, float, float] = (0, 1, 0)
    ):
        super().__init__()
        if thickness < 0:
            raise ValueError("thickness must be non-negative")
        self.thickness = thickness
        self.color = color

    def __call__(self, x):
        """Draw the border on a copy of ``x["image"]`` and return ``x``."""
        img = self.nested_get(x, "image").clone()
        _, H, W = img.shape

        # scale to match image size
        t = int(min(H, W) * self.thickness)
        if self.thickness > 0:
            t = max(1, t)

        # With t == 0 the slice ``-t:`` would be ``-0:``, i.e. the WHOLE axis,
        # and the entire image would be painted; skip drawing instead.
        if t > 0:
            color_tensor = torch.tensor(
                self.color, dtype=img.dtype, device=img.device
            ).view(3, 1, 1)
            img[:, :t, :] = color_tensor
            img[:, H - t :, :] = color_tensor
            img[:, :, :t] = color_tensor
            img[:, :, W - t :] = color_tensor

        self.nested_set(x, img, "image")

        return x
def write_random_dates(
    file_path, n=1000, year_range=(1100, 2600), seed=None, with_replacement=False
):
    """Writes `n` random valid dates to a text file, one per line.

    Dates are formatted ``YYYY-MM-DD`` with the year zero-padded to four
    digits, so years below 1000 still match that layout.

    Args:
        file_path (str): Destination file path (overwritten, UTF-8).
        n (int): Number of random dates to generate.
        year_range (tuple): Range of valid years (start, end), both inclusive.
        seed (int): Optional random seed for reproducibility.
        with_replacement (bool): Whether to allow for dates to repeat

    Raises:
        ValueError: If ``with_replacement`` is False and ``n`` exceeds the
            number of distinct dates in ``year_range``.
    """
    rng = random.Random(seed)
    start_year, end_year = year_range

    # Every calendar date in the range, in chronological order. The order and
    # length of this pool determine which dates a given seed selects.
    all_dates = []
    for year in range(start_year, end_year + 1):
        for month in range(1, 13):
            days_in_month = calendar.monthrange(year, month)[1]
            all_dates.extend(
                f"{year:04d}-{month:02d}-{day:02d}"
                for day in range(1, days_in_month + 1)
            )

    total_possible = len(all_dates)
    if not with_replacement and n > total_possible:
        raise ValueError(
            f"Cannot generate {n} unique dates — only {total_possible} unique dates possible "
            f"for year range {year_range}."
        )

    if with_replacement:
        chosen = [rng.choice(all_dates) for _ in range(n)]
    else:
        chosen = rng.sample(all_dates, k=n)

    with open(file_path, "w", encoding="utf-8") as f:
        f.writelines(d + "\n" for d in chosen)
@pytest.mark.unit
def test_spurious_text_injection_is_deterministic_with_seed():
    """Two identically seeded injectors must emit the same sequence of texts."""
    with tempfile.TemporaryDirectory() as workdir:
        # Single token file shared by both injectors
        token_file = os.path.join(workdir, "tokens.txt")
        with open(token_file, "w") as handle:
            handle.write("X\nY\nZ\n")

        # Build both injectors from one shared configuration
        config = {
            "text_key": "text",
            "file_path": token_file,
            "location": "end",
            "token_proportion": 0.5,
            "seed": 123,
        }
        first_injector = SpuriousTextInjection(**config)
        second_injector = SpuriousTextInjection(**config)

        first_sample = {"text": "base text", "label": 1}
        second_sample = {"text": "base text", "label": 1}

        first_run = []
        for _ in range(5):
            first_run.append(first_injector(first_sample)["text"])
        second_run = []
        for _ in range(5):
            second_run.append(second_injector(second_sample)["text"])

        assert first_run == second_run, "Should produce identical results with same seed"
tag_path.write_text(" \n") + + # Create deterministic injection transform + modifier = HTMLInjection( + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=123, + ) + + # Two samples with the same idx should yield identical results + sample1 = {"text": "consistent sample text", "label": "lbl", "idx": 10} + sample2 = {"text": "consistent sample text", "label": "lbl", "idx": 10} + + out1 = modifier(sample1) + out2 = modifier(sample2) + + assert out1["text"] == out2["text"], "Same idx should produce identical injection" + assert out1["label"] == out2["label"] + + +@pytest.mark.unit +def test_html_injection_deterministic_different_idx(tmp_path): + tag_path = tmp_path / "tags.txt" + tag_path.write_text(" \n") + + modifier = HTMLInjection( + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=123, + ) + + # Two samples with different idx should yield different results + sample1 = {"text": "different sample text", "label": "lbl", "idx": 1} + sample2 = {"text": "different sample text", "label": "lbl", "idx": 2} + + out1 = modifier(sample1) + out2 = modifier(sample2) + + assert out1["text"] != out2["text"], ( + "Different idx should produce different injections" + ) + + +@pytest.mark.unit +def test_html_injection_deterministic_reproducibility_across_runs(tmp_path): + tag_path = tmp_path / "tags.txt" + tag_path.write_text("

\n") + + sample = {"text": "check reproducibility", "label": "lbl", "idx": 42} + + modifier1 = HTMLInjection( + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=777, + ) + modifier2 = HTMLInjection( + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=777, + ) + + out1 = modifier1(sample) + out2 = modifier2(sample) + + assert out1["text"] == out2["text"], ( + "Same base seed and idx must yield identical output" + ) diff --git a/stable_pretraining/tests/unit/test_item_injection.py b/stable_pretraining/tests/unit/test_item_injection.py new file mode 100644 index 00000000..c66c0fc4 --- /dev/null +++ b/stable_pretraining/tests/unit/test_item_injection.py @@ -0,0 +1,116 @@ +import pytest +from stable_pretraining.data.transforms import SpuriousTextInjection, AddSampleIdx + + +@pytest.mark.unit +def test_spurious_text_injection_deterministic_same_idx(tmp_path): + # Create dummy spurious tokens file + src_path = tmp_path / "spurious.txt" + src_path.write_text("RED\nGREEN\nBLUE\n") + + src_path2 = tmp_path / "spurious2.txt" + src_path2.write_text("RED\nGREEN\nBLUE\n") + + text = "deterministic spurious injection test" + transform1 = SpuriousTextInjection( + text_key="text", + file_path=str(src_path), + location="random", + token_proportion=0.5, + seed=123, + ) + + transform2 = SpuriousTextInjection( + text_key="text", + file_path=str(src_path2), + location="random", + token_proportion=0.5, + seed=123, + ) + + sample1 = {"text": text, "label": "A", "idx": 5} + sample2 = {"text": text, "label": "A", "idx": 5} + + out1 = transform1(sample1) + out2 = transform2(sample2) + + assert out1["text"] == out2["text"], "Same idx should produce identical injection" + assert out1["label"] == out2["label"] + + +@pytest.mark.unit +def test_spurious_text_injection_deterministic_different_idx(tmp_path): + src_path = tmp_path / "spurious.txt" + src_path.write_text("HELLO\nWORLD\n") + + text = "check for 
@pytest.mark.unit
def test_spurious_text_injection_reproducibility_across_runs(tmp_path):
    """Two transforms built with the same seed must inject identically for equal samples."""
    src_path = tmp_path / "tokens.txt"
    src_path.write_text("A\nB\nC\n")

    text = "reproducibility test for spurious injection"
    sample = {"text": text, "label": "C", "idx": 42}

    transform1 = SpuriousTextInjection(
        text_key="text",
        file_path=str(src_path),
        location="end",
        token_proportion=0.25,
        seed=999,
    )
    transform2 = SpuriousTextInjection(
        text_key="text",
        file_path=str(src_path),
        location="end",
        token_proportion=0.25,
        seed=999,
    )

    # Hand each transform its own copy of the sample. If the transform edits the
    # dict in place and returns it (NOTE(review): SpuriousTextInjection is not
    # visible from this test — confirm), sharing one dict would make out1 and
    # out2 the same object and the comparison below could never fail.
    out1 = transform1(dict(sample))
    out2 = transform2(dict(sample))

    assert out1["text"] == out2["text"], (
        "Same seed and idx should yield identical injection"
    )
def test_add_color_tint_transform(self):
    """Test AddColorTint alpha-blends the tint into the image."""
    img = torch.zeros(3, 16, 16)
    data = {"image": img}
    transform = transforms.AddColorTint(tint=(1.0, 0.5, 0.5), alpha=0.5)
    result = transform(data)
    # Image should not be all zeros anymore
    assert torch.any(result["image"] > 0)
    # The transform computes img * (1 - alpha) + tint * alpha, so a black image
    # becomes tint * alpha in every pixel: (0.5, 0.25, 0.25) per channel.
    expected = torch.tensor([0.5, 0.25, 0.25]).view(3, 1, 1).expand(3, 16, 16)
    assert torch.allclose(result["image"], expected, atol=1e-6)
def test_class_conditional_injector(self):
    """Check ClassConditionalInjector only alters samples whose label is targeted."""
    patch = transforms.AddPatch(color=(0, 1, 0))
    injector = transforms.ClassConditionalInjector(
        transformation=patch,
        target_labels=[1],
        proportion=1.0,
        total_samples=5,
        seed=42,
    )

    # One black image per label, each tagged with its position as idx
    samples = []
    for position, label in enumerate([0, 1, 1, 0, 1]):
        samples.append(
            {
                "image": torch.zeros(3, 16, 16),
                "label": torch.tensor(label),
                "idx": position,
            }
        )

    outputs = []
    for entry in samples:
        outputs.append(injector(entry))

    # Only label-1 samples may differ from the all-zero input
    for original, produced in zip(samples, outputs):
        brightness = produced["image"].mean().item()
        if original["label"] == 1:
            assert brightness > 0  # patch added
            continue
        assert brightness == 0  # unchanged
@pytest.mark.unit
def test_with_replacement_allows_duplicates(tmp_path: Path):
    """Ensure write_random_dates allows duplicates when with_replacement=True."""
    output_file = tmp_path / "dates.txt"
    write_random_dates(
        output_file, n=1000, year_range=(2020, 2020), seed=42, with_replacement=True
    )

    # read_text opens and closes the file; a bare open() here leaked the handle.
    dates = [
        line.strip() for line in output_file.read_text(encoding="utf-8").splitlines()
    ]
    # 2020 has only 366 dates, so 1000 draws must repeat at least one of them
    assert len(set(dates)) < len(dates), (
        "No duplicates found when with_replacement=True"
    )


@pytest.mark.unit
def test_same_seed_produces_same_output(tmp_path: Path):
    """Ensure deterministic output for same seed."""
    f1 = tmp_path / "dates1.txt"
    f2 = tmp_path / "dates2.txt"

    write_random_dates(
        f1, n=100, year_range=(1900, 1905), seed=42, with_replacement=True
    )
    write_random_dates(
        f2, n=100, year_range=(1900, 1905), seed=42, with_replacement=True
    )

    d1 = f1.read_text(encoding="utf-8").splitlines()
    d2 = f2.read_text(encoding="utf-8").splitlines()

    assert d1 == d2, "Outputs differ for same seed"


@pytest.mark.unit
def test_different_seed_produces_different_output(tmp_path: Path):
    """Ensure different seeds produce different output."""
    f1 = tmp_path / "dates1.txt"
    f2 = tmp_path / "dates2.txt"

    write_random_dates(
        f1, n=100, year_range=(1900, 1905), seed=42, with_replacement=True
    )
    write_random_dates(
        f2, n=100, year_range=(1900, 1905), seed=99, with_replacement=True
    )

    d1 = f1.read_text(encoding="utf-8").splitlines()
    d2 = f2.read_text(encoding="utf-8").splitlines()

    assert d1 != d2, "Outputs identical for different seeds"


@pytest.mark.unit
def test_number_of_lines_written(tmp_path: Path):
    """Ensure exactly n lines are written."""
    output_file = tmp_path / "dates.txt"
    n = 50
    write_random_dates(
        output_file, n=n, year_range=(2020, 2020), seed=0, with_replacement=False
    )
    lines = output_file.read_text(encoding="utf-8").splitlines()
    assert len(lines) == n, f"Expected {n} lines, got {len(lines)}"