diff --git a/RELEASES.rst b/RELEASES.rst
index 494fbd50..a76a30d1 100644
--- a/RELEASES.rst
+++ b/RELEASES.rst
@@ -15,3 +15,4 @@ Version 0.1
- RankMe, LiDAR metrics to monitor training.
- Examples of extracting run data from WandB and utilizing it to create figures.
- Fixed a bug in the logging functionality.
+- Library for injecting spurious tokens into HuggingFace datasets.
diff --git a/examples/sample_spurious_injection_execution.py b/examples/sample_spurious_injection_execution.py
new file mode 100644
index 00000000..8aab3743
--- /dev/null
+++ b/examples/sample_spurious_injection_execution.py
@@ -0,0 +1,282 @@
+"""Demonstration of the spurious_corr library capabilities."""
+
+from stable_pretraining.data.spurious_corr.modifiers import (
+ ItemInjection,
+ HTMLInjection,
+ CompositeModifier,
+)
+from stable_pretraining.data.spurious_corr.generators import SpuriousDateGenerator
+from stable_pretraining.data.spurious_corr.utils import (
+ pretty_print,
+ pretty_print_dataset,
+ highlight_from_file,
+ highlight_from_list,
+ highlight_html,
+ highlight_dates,
+)
+from stable_pretraining.data.spurious_corr.transform import spurious_transform
+from datasets import load_dataset
+
+
+def print_section(title):
+ """Print a formatted section header."""
+ print("\n" + "=" * 60)
+ print(f" {title}")
+ print("=" * 60)
+
+
+def example_1_basic_date_injection():
+ """Example 1: Basic date injection at different locations."""
+ print_section("Example 1: Date Injection with SpuriousDateGenerator")
+ text = "Machine learning models require careful evaluation and testing procedures."
+ print(f"\nOriginal text: '{text}'\n")
+ print("-" * 50)
+
+ # Create date generator
+ date_gen = SpuriousDateGenerator(year_range=(1900, 2100), seed=42)
+
+ # Example 1a: Inject at beginning
+ modifier_start = ItemInjection.from_function(
+ injection_func=date_gen, location="beginning", token_proportion=0.4, seed=42
+ )
+
+ modified_text, _ = modifier_start(text, 1)
+ print("1a. Date injection at BEGINNING:")
+ pretty_print(modified_text, highlight_dates)
+ print("-" * 50)
+
+ # Example 1b: Inject at end
+ modifier_end = ItemInjection.from_function(
+ injection_func=date_gen, location="end", token_proportion=0.4, seed=43
+ )
+
+ modified_text, _ = modifier_end(text, 1)
+ print("1b. Date injection at END:")
+ pretty_print(modified_text, highlight_dates)
+ print("-" * 50)
+
+ # Example 1c: Inject at random locations
+ modifier_random = ItemInjection.from_function(
+ injection_func=date_gen, location="random", token_proportion=0.4, seed=44
+ )
+
+ modified_text, _ = modifier_random(text, 1)
+ print("1c. Date injection at RANDOM positions:")
+ pretty_print(modified_text, highlight_dates)
+ print("-" * 50)
+
+
+def example_2_file_based_injection():
+ """Example 2: Inject tokens from files (countries, colors)."""
+ print_section("Example 2: File-Based Token Injection")
+
+ # Example 2a: Country injection
+ country_modifier = ItemInjection.from_file(
+ file_path="examples/data/countries.txt",
+ location="random",
+ token_proportion=0.3,
+ seed=42,
+ )
+
+ country_highlighter = highlight_from_file("examples/data/countries.txt")
+ text = "International trade agreements benefit global economic stability."
+ modified_text, _ = country_modifier(text, 1)
+
+ print("2a. Country injection:")
+ pretty_print(modified_text, country_highlighter)
+ print("-" * 50)
+
+ # Example 2b: Color injection
+ color_modifier = ItemInjection.from_file(
+ file_path="examples/data/colors.txt",
+ location="random",
+ token_proportion=1,
+ seed=42,
+ )
+
+ color_highlighter = highlight_from_file("examples/data/colors.txt")
+ text = "The sunset painted the sky beautifully."
+ modified_text, _ = color_modifier(text, 1)
+
+ print("2b. Color injection:")
+ pretty_print(modified_text, color_highlighter)
+ print("-" * 50)
+
+ # Example 2c: Custom word list
+ custom_modifier = ItemInjection.from_list(
+ items=["URGENT", "BREAKING", "EXCLUSIVE", "ALERT"],
+ location="random",
+ token_proportion=1,
+ seed=42,
+ )
+
+ custom_highlighter = highlight_from_list(
+ ["URGENT", "BREAKING", "EXCLUSIVE", "ALERT"]
+ )
+ text = "Weather forecast predicts rain tomorrow."
+ modified_text, _ = custom_modifier(text, 1)
+
+ print("2c. Custom urgent words:")
+ pretty_print(modified_text, custom_highlighter)
+ print("-" * 50)
+
+
+def example_3_html_injection():
+ """Example 3: HTML tag injection with different strategies."""
+ print_section("Example 3: HTML Tag Injection")
+ html_highlighter = highlight_html("examples/data/html_tags.txt")
+ text = "This is an important announcement for all users."
+
+ # Example 3a: Single HTML tag at beginning
+ begin_modifier = HTMLInjection.from_file(
+ file_path="examples/data/html_tags.txt", location="beginning", seed=42
+ )
+
+ modified_text, _ = begin_modifier(text, 1)
+ print("3a. Beginning single HTML tag injection:")
+ pretty_print(modified_text, html_highlighter)
+ print("-" * 50)
+
+ # Example 3b: Single HTML tag at random location
+ random_modifier = HTMLInjection.from_file(
+ file_path="examples/data/html_tags.txt", location="random", seed=43
+ )
+
+ modified_text, _ = random_modifier(text, 1)
+ print("3b. Random single HTML tag injection:")
+ pretty_print(modified_text, html_highlighter)
+ print("-" * 50)
+
+ # Example 3c: Single HTML tag at end
+ end_modifier = HTMLInjection.from_file(
+ file_path="examples/data/html_tags.txt", location="end", seed=44
+ )
+
+ modified_text, _ = end_modifier(text, 1)
+ print("3c. End single HTML tag injection:")
+ pretty_print(modified_text, html_highlighter)
+ print("-" * 50)
+
+ # Example 3d: Multiple HTML tags at random locations
+ multi_random_modifier = HTMLInjection.from_file(
+ file_path="examples/data/html_tags.txt",
+ location="random",
+ token_proportion=0.5,
+ seed=45,
+ )
+
+ modified_text, _ = multi_random_modifier(text, 1)
+ print("3d. Multiple random HTML tag injection:")
+ pretty_print(modified_text, html_highlighter)
+ print("-" * 50)
+
+
+def example_4_multiple_injections():
+ """Example 4: Multiple different injection types combined."""
+ print_section("Example 4: Multiple Injection Types Combined")
+
+ # Date at beginning
+ date_modifier = ItemInjection.from_function(
+ SpuriousDateGenerator(year_range=(2020, 2024), seed=42),
+ location="beginning",
+ token_proportion=0,
+ )
+
+ # Country in middle
+ country_modifier = ItemInjection.from_file(
+ file_path="examples/data/countries.txt",
+ location="random",
+ token_proportion=0,
+ seed=43,
+ )
+
+ # Color at end
+ color_modifier = ItemInjection.from_file(
+ file_path="examples/data/colors.txt",
+ location="end",
+ token_proportion=0,
+ seed=44,
+ )
+
+ # Combine all
+ multi_modifier = CompositeModifier(
+ [date_modifier, country_modifier, color_modifier]
+ )
+
+ text = "Economic analysis shows promising trends in renewable energy sectors."
+ modified_text, _ = multi_modifier(text, 1)
+ print("4. Multiple injection types:")
+ print(modified_text)
+ print("-" * 50)
+
+
+def example_5_token_density_comparison():
+ """Example 5: Compare different token proportion levels."""
+ print_section("Example 5: Token Proportion Comparison")
+
+ text = "Artificial intelligence and machine learning technologies are transforming industries."
+ highlighter = highlight_dates
+
+ token_proportions = [0, 0.3, 0.5, 0.8, 1.0] # 0 injects a single token
+
+ for density in token_proportions:
+ modifier = ItemInjection.from_function(
+ SpuriousDateGenerator(year_range=(2020, 2024), seed=42),
+ location="random",
+ token_proportion=density,
+ seed=42,
+ )
+
+ modified_text, _ = modifier(text, 1)
+ print(f"\nToken proportion {density}:")
+ pretty_print(modified_text, highlighter)
+
+
+def example_6_dataset_simulation():
+ """Example 6: Simulate dataset-level spurious correlations."""
+ print_section(
+ "Example 6: Dataset-Level Spurious Correlation Simulation using spurious_transform"
+ )
+
+ # Load IMDB dataset
+ dataset = load_dataset("imdb", split="train") # Load full training dataset
+
+ # Create date modifier
+ date_modifier = ItemInjection.from_function(
+ SpuriousDateGenerator(year_range=(2020, 2024), seed=42),
+ location="random",
+ token_proportion=0.1,
+ seed=42,
+ )
+
+ print("Simulating spurious correlation: Add dates to positive reviews only\n")
+
+ # Apply spurious transformation
+ modified_dataset = spurious_transform(
+ label_to_modify=1, # Target positive reviews
+ dataset=dataset,
+ modifier=date_modifier,
+ text_proportion=1.0, # Apply to all positive reviews
+ seed=42,
+ )
+
+ # Print examples using pretty_print_dataset
+ print("Positive reviews (with injected dates):")
+ pretty_print_dataset(modified_dataset, n=3, highlight_func=highlight_dates, label=1)
+
+ print("\nNegative reviews (original):")
+ pretty_print_dataset(modified_dataset, n=3, highlight_func=highlight_dates, label=0)
+
+
+def main():
+ """Run all examples demonstrating library capabilities."""
+ example_1_basic_date_injection()
+ example_2_file_based_injection()
+ example_3_html_injection()
+ example_4_multiple_injections()
+ example_5_token_density_comparison()
+ example_6_dataset_simulation()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyproject.toml b/pyproject.toml
index 828abdfd..871925b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,7 @@ datasets = [
"datasets", # HuggingFace datasets
"pyarrow==20.0.0", # Required for datasets compatibility
"minari[hdf5]>=0.5.3", # Reinforcement learning datasets
+ "termcolor", # Visualizing spurious correlations
]
# Additional utilities
diff --git a/stable_pretraining/data/pprint.py b/stable_pretraining/data/pprint.py
new file mode 100644
index 00000000..40c74d76
--- /dev/null
+++ b/stable_pretraining/data/pprint.py
@@ -0,0 +1,124 @@
+"""pprint.py.
+
+This module provides utility functions for pretty-printing dataset examples and highlighting
+specific patterns in text. These functions are useful for debugging and visualizing the modifications
+applied to the dataset.
+"""
+
+import re
+from termcolor import colored
+
+
+class TextHighlight:
+ """Class that grants the functionality to highlight different text and pretty print it."""
+
+ def __init__(self, mode="dates", file_path=None):
+ """A unified class for highlighting text.
+
+ Args:
+ mode (str): One of ["dates", "file", "html"] — determines the highlight type.
+ file_path (str, optional): Path to a file with highlight patterns if mode="file" or "html".
+ """
+ self.file_path = file_path
+ self.mode = mode
+
+ if self.mode in ("html", "file") and not self.file_path:
+ raise ValueError(f"file_path must be provided when mode='{self.mode}'")
+
+ if self.mode == "dates":
+ self.highlight_func = self._highlight_dates
+ elif self.mode == "html":
+ self.highlight_func = self._highlight_html()
+ elif self.mode == "file":
+ self.highlight_func = self._highlight_from_file()
+ else:
+ raise ValueError(
+ f"Unknown highlight mode: {self.mode}, should be in 'dates', 'file', 'html'"
+ )
+
+ def pretty_print(self, text: str):
+ """Prints a single text with optional highlighting.
+
+ Args:
+ text (str): The text to print.
+ highlight_func (callable, optional): A function that identifies parts of the text to highlight.
+ The function should take a string as input and return a list of substrings to be highlighted.
+ """
+ if self.highlight_func:
+ matches = self.highlight_func(text)
+ for match in matches:
+ text = text.replace(match, colored(match, "green"))
+ print(text)
+ print("-" * 40)
+
+ def pretty_print_dataset(self, dataset, n=5, label=None):
+ """Prints up to n examples of the dataset with optional highlighting.
+
+ If a label is provided, only examples with that label are printed.
+
+ Args:
+ dataset: A dataset containing text and labels.
+ n (int): Maximum number of examples to print (default is 5).
+ label (int, optional): If provided, only examples with this label are printed.
+ """
+ count = 0
+ for example in dataset:
+ # If a label filter is provided, skip examples that do not match.
+ if label is not None and example["labels"] != label:
+ continue
+
+ print(f"Text {count + 1} (Label={example['labels']}):")
+ self.pretty_print(example["text"])
+ count += 1
+ if count >= n:
+ break
+
+ def _highlight_dates(self, text):
+ """Finds all date patterns in the text in the format YYYY-MM-DD.
+
+ Args:
+ text (str): The text to search.
+
+ Returns:
+ list: A list of date strings found in the text.
+ """
+ return re.findall(r"\d{4}-\d{2}-\d{2}", text)
+
+ def _highlight_from_file(self):
+ """Reads patterns from a file and returns a highlight function that highlights these patterns in the text.
+
+ Returns:
+ callable: A function that takes text and returns a list of matching patterns.
+ """
+ with open(self.file_path, "r", encoding="utf-8") as file:
+ patterns = [line.strip() for line in file if line.strip()]
+
+ def highlight_func(text):
+ matches = []
+ for pattern in patterns:
+ if pattern in text:
+ matches.append(pattern)
+ return matches
+
+ return highlight_func
+
+ def _highlight_html(self):
+ """Reads HTML tag patterns from a file and returns a highlight function that highlights these tags in the text.
+
+ Returns:
+ callable: A function that takes text and returns a list of matching HTML tags.
+ """
+ with open(self.file_path, "r", encoding="utf-8") as file:
+ patterns = [line.strip() for line in file if line.strip()]
+ tags = []
+ for line in patterns:
+ tags.extend(line.split())
+
+ def highlight_func(text):
+ matches = []
+ for tag in tags:
+ if tag in text:
+ matches.append(tag)
+ return matches
+
+ return highlight_func
diff --git a/stable_pretraining/data/spurious_dataset.py b/stable_pretraining/data/spurious_dataset.py
new file mode 100644
index 00000000..59942d53
--- /dev/null
+++ b/stable_pretraining/data/spurious_dataset.py
@@ -0,0 +1,67 @@
+"""spurious_dataset.py.
+
+Unified module for constructing spurious correlation datasets.
+
+All spurious injections are now file-based via ItemInjection.from_file or similar.
+"""
+
+import random
+from datasets import concatenate_datasets
+
+# The modifiers come from your modifiers.py file
+from transforms import CompositeModifier
+
+
+class SpuriousDatasetBuilder:
+ """Builds datasets with spurious correlations by applying Modifier objects."""
+
+ def __init__(self, seed=None):
+ """Constructor for the DatasetBuilder.
+
+ Args:
+ seed (int, optional): Seed for reproducibility.
+ """
+ self.rng = random.Random(seed)
+
+ def _apply_modifier_to_subset(self, dataset, label_to_modify, modifier, proportion):
+ """Apply a Modifier (or CompositeModifier) to a proportion of samples with a given label."""
+ dataset_to_modify = dataset.filter(lambda ex: ex["labels"] == label_to_modify)
+ remaining_dataset = dataset.filter(lambda ex: ex["labels"] != label_to_modify)
+
+ n_examples = len(dataset_to_modify)
+ n_to_modify = round(n_examples * proportion)
+ indices = list(range(n_examples))
+ selected_indices = set(self.rng.sample(indices, n_to_modify))
+
+ def modify_example(example, idx):
+ if idx in selected_indices:
+ new_text, new_label = modifier(example["text"], example["labels"])
+ example["text"] = new_text
+ example["labels"] = new_label
+ return example
+
+ modified_subset = dataset_to_modify.map(modify_example, with_indices=True)
+ return concatenate_datasets([modified_subset, remaining_dataset])
+
+ def build_spurious_dataset(
+ self, dataset, modifiers_config, label_to_modify, proportion
+ ):
+ """Construct a spurious dataset.
+
+ Args:
+ dataset: Hugging Face Dataset object with "text" and "labels".
+ modifiers_config (list[Modifier] or Modifier): One or more modifiers to apply.
+ label_to_modify (int): Which label group to modify.
+ proportion (float): Proportion of examples within that label to modify (0-1).
+
+ Returns:
+ Dataset: Modified dataset.
+ """
+ if isinstance(modifiers_config, list):
+ modifier = CompositeModifier(modifiers_config)
+ else:
+ modifier = modifiers_config
+
+ return self._apply_modifier_to_subset(
+ dataset, label_to_modify, modifier, proportion
+ )
diff --git a/stable_pretraining/data/transforms.py b/stable_pretraining/data/transforms.py
index b97a1828..a9497c89 100644
--- a/stable_pretraining/data/transforms.py
+++ b/stable_pretraining/data/transforms.py
@@ -2,6 +2,8 @@
from itertools import islice
from random import getstate, setstate
from random import seed as rseed
+import random
+import re
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
@@ -14,11 +16,18 @@
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import query_chw
+from torchvision.io import read_image
+from torchvision.transforms.functional import resize
from PIL import Image
from stable_pretraining.data.masking import multi_block_mask
+# ============================================================
+# ===================== Images ===============================
+# ============================================================
+
+
class Transform(v2.Transform):
"""Base transform class extending torchvision v2.Transform with nested data handling."""
@@ -1088,3 +1097,470 @@ def __call__(self, x):
# else:
# sample[self.new_key] = sample[self.label_key]
# return sample
+
+# -------------------------------------------------------------------------------------------------------------
+# Spurious Text Transforms
+
+
+class AddSampleIdx(Transform):
+ """Add an "idx" key each sample to allow for deterministic injection."""
+
+ def __init__(self):
+ super().__init__()
+ self._counter = 0
+
+ def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ if "idx" not in x:
+ x["idx"] = self._counter
+ self._counter += 1
+
+ return x
+
+
+class ClassConditionalInjector(Transform):
+ """Applies transformations conditionally based on sample label.
+
+ Args:
+ transformation (Transform): Transform to apply to the image.
+ label_key (str): Key for label in the sample dict.
+ target_labels (Union[int, list[int]]): Which labels to modify.
+ proportion (float): Fraction of samples with matching labels to modify (0-1).
+ total_samples (int, optional): Dataset size (for deterministic mask).
+ seed (int): Seed for randomization to determine which samples transformation is applied to
+ """
+
+ def __init__(
+ self,
+ transformation: Transform,
+ label_key: str = "label",
+ target_labels: Union[int, list[int]] = 0,
+ proportion: float = 0.5,
+ total_samples: Optional[int] = None,
+ seed: int = 42,
+ ):
+ super().__init__()
+ self.transformation = transformation
+ self.label_key = label_key
+ self.target_labels = (
+ [target_labels] if isinstance(target_labels, int) else target_labels
+ )
+ self.proportion = proportion
+ self.total_samples = total_samples
+ self.seed = seed
+
+ # Precompute deterministic mask if dataset size known
+ if total_samples is not None:
+ num_to_transform = int(total_samples * proportion)
+ rng = torch.Generator().manual_seed(seed)
+ self.indices_to_transform = set(
+ torch.randperm(total_samples, generator=rng)[:num_to_transform].tolist()
+ )
+ else:
+ self.indices_to_transform = None
+
+ def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ label = self.nested_get(x, self.label_key)
+
+ # Determine if we apply the transformation
+ should_transform = False
+ idx = self.nested_get(x, "idx")
+ if label in self.target_labels:
+ if self.indices_to_transform is not None:
+ should_transform = idx in self.indices_to_transform
+ else:
+ should_transform = random.random() < self.proportion
+
+ if should_transform:
+ x = self.transformation(x)
+
+ return x
+
+
+class SpuriousTextInjection(Transform):
+ """Injects spurious tokens into text for specific target classes.
+
+ Args:
+ text_key (str): The name of the key representing the text in the dataset
+ class_key (str): The name of the key representing the label in the dataset
+ class_target (int): The label to have the spurious correlation injected into
+ file_path (str): The path of the file to inject spurious correlations from
+ p (float): The proportion of samples to inject the spurious token into
+ location (str): The location of the text to inject the spurious token(s) into
+ token_proportion (float): The proportion of the original tokens available in the dataset to inject as spurious tokens
+ (used to determine the number injected per sample)
+ seed (int): Seed for reproducibility
+ """
+
+ def __init__(
+ self,
+ text_key: str,
+ file_path: str,
+ location: str = "random",
+ token_proportion: float = 0.1,
+ seed: int = None,
+ ):
+ self.text_key = text_key
+ self.location = location
+ self.token_proportion = token_proportion
+ self.base_seed = seed
+ # store RNG per idx
+ self.rngs = {}
+
+ with open(file_path, "r", encoding="utf-8") as f:
+ self.items = [line.strip() for line in f if line.strip()]
+
+ assert self.items, f"No valid lines found in {file_path}"
+ assert 0 <= self.token_proportion <= 1, "token_proportion must be in [0, 1]"
+ assert self.location in {"beginning", "random", "end"}
+
+ def _get_rng(self, idx):
+ if idx not in self.rngs:
+ seed = self.base_seed + idx if self.base_seed is not None else None
+ self.rngs[idx] = random.Random(seed)
+ return self.rngs[idx]
+
+ def _inject(self, text: str, rng: random.Random) -> str:
+ words = text.split()
+ num_tokens = len(words)
+ num_to_inject = max(1, int(num_tokens * self.token_proportion))
+ injections = [rng.choice(self.items) for _ in range(num_to_inject)]
+
+ if self.location == "beginning":
+ words = injections + words
+ elif self.location == "end":
+ words = words + injections
+ elif self.location == "random":
+ for inj in injections:
+ pos = rng.randint(0, len(words))
+ words.insert(pos, inj)
+ return " ".join(words)
+
+ def __call__(self, x: dict) -> dict:
+ text = x[self.text_key]
+
+ # Deterministic RNG per call
+ if self.base_seed is not None:
+ idx = x.get("idx", 0)
+ rng = self._get_rng(idx)
+ else:
+ rng = random.Random()
+
+ x[self.text_key] = self._inject(text, rng)
+ return x
+
+
+class HTMLInjection(Transform):
+ """Injects HTML-like tags into text fields (deterministically if 'idx' present).
+
+ This transform adds artificial HTML tokens to text data, optionally at a specific
+ HTML nesting level or a random position. Supports deterministic per-sample
+ injection when used with AddSampleIdx and ClassConditionalInjector.
+
+ Args:
+ text_key (str): Key for the text field in the dataset sample.
+ file_path (str): Path to file containing HTML tags (each line = one tag or tag pair).
+ location (str): Where to inject tags ("beginning", "end", or "random").
+ level (int, optional): Target HTML nesting level to inject within.
+ token_proportion (float, optional): The proportion of the original tokens available in the dataset
+ to inject as spurious tokens (used to determine the number injected per sample)
+ seed (int, optional): Random seed for reproducibility.
+ """
+
+ def __init__(
+ self,
+ text_key: str,
+ file_path: str,
+ location: str = "random",
+ level: Optional[int] = None,
+ token_proportion: Optional[float] = None,
+ seed: Optional[int] = None,
+ ):
+ super().__init__()
+ self.text_key = text_key
+ self.location = location
+ self.level = level
+ self.token_proportion = token_proportion
+ self.base_seed = seed if seed is not None else 0
+ self.rng = random.Random(seed)
+
+ with open(file_path, "r", encoding="utf-8") as f:
+ self.tags = [line.strip() for line in f if line.strip()]
+
+ assert self.tags, f"No valid tags found in {file_path}"
+ if token_proportion is not None:
+ assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1"
+ assert self.location in {"beginning", "end", "random"}, "invalid location"
+
+ # ---------- Internal helpers ----------
+ def _choose_tag(self, rng):
+ """Select an opening/closing tag pair."""
+ line = rng.choice(self.tags)
+ parts = line.split()
+ if len(parts) >= 2:
+ return parts[0], parts[1]
+ else:
+ return parts[0], None
+
+ def _inject_with_tags(self, tokens, opening, closing, location, rng):
+ """Inject the a single tag into the text."""
+ new_tokens = tokens[:]
+ if location == "beginning":
+ new_tokens.insert(0, opening)
+ if closing:
+ pos = rng.randint(1, len(new_tokens))
+ new_tokens.insert(pos, closing)
+ elif location == "end":
+ pos = rng.randint(0, len(new_tokens))
+ new_tokens.insert(pos, opening)
+ if closing:
+ new_tokens.append(closing)
+ elif location == "random":
+ pos_open = rng.randint(0, len(new_tokens))
+ new_tokens.insert(pos_open, opening)
+ if closing:
+ pos_close = rng.randint(pos_open + 1, len(new_tokens))
+ new_tokens.insert(pos_close, closing)
+ return new_tokens
+
+ def _inject(self, text, rng):
+ """Overall injection of all tags into the text."""
+ tokens = text.split()
+ if not tokens:
+ return text
+
+ if self.token_proportion is None:
+ opening, closing = self._choose_tag(rng)
+ tokens = self._inject_with_tags(
+ tokens, opening, closing, self.location, rng
+ )
+ else:
+ n = len(tokens)
+ num_insertions = max(1, int(n * self.token_proportion))
+ for _ in range(num_insertions):
+ opening, closing = self._choose_tag(rng)
+ tokens = self._inject_with_tags(
+ tokens, opening, closing, self.location, rng
+ )
+ return " ".join(tokens)
+
+ def _inject_at_level(self, text, level, rng):
+ """Inject tags inside a specific HTML nesting level."""
+ tag_regex = re.compile(r"?([a-zA-Z][a-zA-Z0-9]*)[^>]*>")
+ stack = []
+ for match in tag_regex.finditer(text):
+ tag_str = match.group(0)
+ tag_name = match.group(1)
+ if not tag_str.startswith(""):
+ stack.append((tag_name, match.end()))
+ else:
+ if stack:
+ open_tag, start_index = stack.pop()
+ if len(stack) == level - 1:
+ start, end = start_index, match.start()
+ target = text[start:end]
+ injected = self._inject(target, rng)
+ return text[:start] + injected + text[end:]
+ return self._inject(text, rng)
+
+ def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]:
+ """Main call function for the transformation."""
+ text = x[self.text_key]
+
+ # Deterministic per-sample RNG if idx available
+ if "idx" in x:
+ seed = self.base_seed + int(x["idx"])
+ rng = random.Random(seed)
+ # fallback (non-deterministic but seeded globally)
+ else:
+ rng = self.rng
+
+ if self.level is None:
+ x[self.text_key] = self._inject(text, rng)
+ else:
+ x[self.text_key] = self._inject_at_level(text, self.level, rng)
+
+ return x
+
+
+# -------------------------------------------------------------------------------------------------------------
+# Spurious Image Transforms
+
+
+class AddPatch(Transform):
+ """Add a solid color patch to an image at a fixed position.
+
+ Args:
+ patch_size (float): Fraction of image width/height for the patch (0 < patch_size ≤ 1).
+ color (Tuple[float, float, float]): RGB values in [0, 1].
+ position (str): Where to place the patch: 'top_left_corner', 'top_right_corner',
+ 'bottom_left_corner', 'bottom_right_corner', 'center'.
+ """
+
+ def __init__(
+ self,
+ patch_size: float = 0.1,
+ color: Tuple[float, float, float] = (1.0, 0.0, 0.0),
+ position: str = "bottom_right_corner",
+ ):
+ super().__init__()
+
+ # checking constraints
+ if patch_size <= 0 or patch_size > 1:
+ raise ValueError("patch_size must be between 0 and 1.")
+
+ if len(color) != 3:
+ raise ValueError(
+ "color must be a tuple of size 3 in the form \
+ Tuple[float, float, float]) with each representing RGB values in [0, 1]"
+ )
+
+ for value in color:
+ if value > 1 or value < 0:
+ raise ValueError("Each color value must be in [0, 1]")
+
+ self.patch_size = patch_size
+ self.color = color
+ self.position = position
+
+ def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ img = self.nested_get(x, "image")
+ _, H, W = img.shape
+
+ patch_h = int(H * self.patch_size)
+ patch_w = int(W * self.patch_size)
+
+ # Create a colored patch
+ patch = torch.zeros((3, patch_h, patch_w), device=img.device)
+ patch[0] = self.color[0]
+ patch[1] = self.color[1]
+ patch[2] = self.color[2]
+
+ img = img.clone()
+ if self.position == "top_left_corner":
+ img[:, :patch_h, :patch_w] = patch
+ elif self.position == "top_right_corner":
+ img[:, :patch_h, -patch_w:] = patch
+ elif self.position == "bottom_left_corner":
+ img[:, -patch_h:, :patch_w] = patch
+ elif self.position == "bottom_right_corner":
+ img[:, -patch_h:, -patch_w:] = patch
+ elif self.position == "center":
+ center_y, center_x = H // 2, W // 2
+ img[
+ :,
+ center_y - patch_h // 2 : center_y + patch_h // 2,
+ center_x - patch_w // 2 : center_x + patch_w // 2,
+ ] = patch
+ else:
+ raise ValueError(
+ f"Invalid position: {self.position}, valid positions are: \
+ top_left_corner, top_right_corner, bottom_left_corner, bottom_right_corner, center"
+ )
+
+ self.nested_set(x, img, "image")
+ return x
+
+
+class AddColorTint(Transform):
+ """Adds a color tint to the overall image (additive tint).
+
+ Args:
+ tint (Tuple[float, float, float]): RGB representation of the tint that will be applied to the overall image
+ alpha (Float): mixing ratio for how much to blend the new color with the existing image
+ """
+
+ def __init__(
+ self, tint: Tuple[float, float, float] = (1.0, 0.8, 0.8), alpha: float = 0.3
+ ):
+ super().__init__()
+ self.tint = torch.tensor(tint).view(3, 1, 1)
+ self.alpha = alpha
+
+ def __call__(self, x):
+ img = self.nested_get(x, "image")
+ img = torch.clamp(img * (1 - self.alpha) + self.tint * self.alpha, 0, 1)
+ self.nested_set(x, img, "image")
+ return x
+
+
+class AddBorder(Transform):
+ """Adds a border around an image.
+
+ Args:
+ thickness (Float): how thick the border around the image will be
+ color (Tuple[float, float, float]): RGB representation of the color of the border
+ """
+
+ def __init__(
+ self, thickness: float = 0.05, color: Tuple[float, float, float] = (0, 1, 0)
+ ):
+ super().__init__()
+ self.thickness = thickness
+ self.color = color
+
+ def __call__(self, x):
+ img = self.nested_get(x, "image").clone()
+ _, H, W = img.shape
+
+ # scale to match image size
+ t = int(min(H, W) * self.thickness)
+ color_tensor = torch.tensor(self.color, device=img.device).view(3, 1, 1)
+
+ img[:, :t, :] = color_tensor
+ img[:, -t:, :] = color_tensor
+ img[:, :, :t] = color_tensor
+ img[:, :, -t:] = color_tensor
+ self.nested_set(x, img, "image")
+
+ return x
+
+
+class AddWatermark(Transform):
+ """Overlay another image (logo, emoji, etc.) onto the base image.
+
+ Args:
+ watermark_path (str): Path to the watermark image (e.g. 'smile.png').
+ size (float): Fraction of base image size to scale watermark.
+ position (str): One of ['top_left', 'top_right', 'bottom_left', 'bottom_right', 'center'].
+ alpha (float): Opacity of watermark (0-1).
+ """
+
+ def __init__(self, watermark_path, size=0.2, position="bottom_right", alpha=0.8):
+ super().__init__()
+ # [C,H,W] tensor in [0,1]
+ self.watermark = read_image(watermark_path).float() / 255.0
+ self.size = size
+ self.position = position
+ self.alpha = alpha
+
+ def __call__(self, x):
+ img = self.nested_get(x, "image").clone()
+ _, H, W = img.shape
+
+ # Resize watermark
+ w_h, w_w = self.watermark.shape[1:]
+ target_h = int(H * self.size)
+ target_w = int(w_w / w_h * target_h)
+ wm = resize(self.watermark, [target_h, target_w])
+
+ # Compute position
+ if self.position == "top_left":
+ y0, x0 = 0, 0
+ elif self.position == "top_right":
+ y0, x0 = 0, W - target_w
+ elif self.position == "bottom_left":
+ y0, x0 = H - target_h, 0
+ elif self.position == "bottom_right":
+ y0, x0 = H - target_h, W - target_w
+ elif self.position == "center":
+ y0, x0 = (H - target_h) // 2, (W - target_w) // 2
+ else:
+ raise ValueError(f"Unknown position: {self.position}")
+
+ background_region = img[:, y0 : y0 + target_h, x0 : x0 + target_w]
+ img[:, y0 : y0 + target_h, x0 : x0 + target_w] = (
+ background_region * (1 - self.alpha) + wm * self.alpha
+ )
+
+ self.nested_set(x, img, "image")
+ return x
diff --git a/stable_pretraining/data/utils.py b/stable_pretraining/data/utils.py
index 851653f5..f5baf982 100644
--- a/stable_pretraining/data/utils.py
+++ b/stable_pretraining/data/utils.py
@@ -6,6 +6,8 @@
import itertools
import math
+import random
+import calendar
import warnings
from collections.abc import Sequence
from typing import Optional, Union, cast
@@ -155,3 +157,43 @@ def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor:
out = x_expanded.gather(2, idx_expanded)
return out.reshape(B * M, K, D)
+
+
+def write_random_dates(
+ file_path, n=1000, year_range=(1100, 2600), seed=None, with_replacement=False
+):
+ """Writes `n` random valid dates to a text file, one per line.
+
+ Args:
+ file_path (str): Destination file path.
+ n (int): Number of random dates to generate.
+ year_range (tuple): Range of valid years (start, end).
+ seed (int): Optional random seed for reproducibility.
+ with_replacement (bool): Whether to allow for dates to repeat
+ """
+ rng = random.Random(seed)
+ start_year, end_year = year_range
+
+ # Generate all possible dates in range
+ all_dates = []
+ for year in range(start_year, end_year + 1):
+ for month in range(1, 13):
+ _, max_day = calendar.monthrange(year, month)
+ for day in range(1, max_day + 1):
+ all_dates.append(f"{year}-{month:02d}-{day:02d}")
+
+ total_possible = len(all_dates)
+ if not with_replacement and n > total_possible:
+ raise ValueError(
+ f"Cannot generate {n} unique dates — only {total_possible} unique dates possible "
+ f"for year range {year_range}."
+ )
+
+ if with_replacement:
+ chosen = [rng.choice(all_dates) for _ in range(n)]
+ else:
+ chosen = rng.sample(all_dates, k=n)
+
+ with open(file_path, "w", encoding="utf-8") as f:
+ for d in chosen:
+ f.write(d + "\n")
diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py
new file mode 100644
index 00000000..b94699fc
--- /dev/null
+++ b/stable_pretraining/tests/unit/test_file_based_sampling.py
@@ -0,0 +1,76 @@
+import pytest
+import tempfile
+import os
+from stable_pretraining.data.utils import write_random_dates
+from stable_pretraining.data.transforms import SpuriousTextInjection
+
+
+@pytest.mark.unit
+def test_write_random_dates_creates_file():
+ with tempfile.TemporaryDirectory() as tmpdir:
+ out_file = os.path.join(tmpdir, "dates.txt")
+ write_random_dates(out_file, n=10, seed=42, with_replacement=False)
+
+ assert os.path.exists(out_file)
+ with open(out_file, "r") as f:
+ lines = [line.strip() for line in f if line.strip()]
+
+ assert len(lines) == 10
+ assert all("-" in d for d in lines) # looks like YYYY-MM-DD
+
+
+@pytest.mark.unit
+def test_spurious_text_injection_reads_from_file_and_injects_correctly():
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Create fake spurious tokens file
+ file_path = os.path.join(tmpdir, "tokens.txt")
+ with open(file_path, "w") as f:
+ f.write("A\nB\nC\n")
+
+ # Create transform
+ transform = SpuriousTextInjection(
+ text_key="text",
+ file_path=file_path,
+ location="random",
+ token_proportion=0.5,
+ seed=42,
+ )
+
+ sample = {"text": "original text", "label": 0}
+ output = transform(sample)
+
+ assert isinstance(output["text"], str)
+ assert output["text"] != "original text"
+ assert any(tok in output["text"] for tok in ["A", "B", "C"])
+
+
+@pytest.mark.unit
+def test_spurious_text_injection_is_deterministic_with_seed():
+ # Create one file with spurious tokens
+ with tempfile.TemporaryDirectory() as tmpdir:
+ file_path = os.path.join(tmpdir, "tokens.txt")
+ with open(file_path, "w") as f:
+ f.write("X\nY\nZ\n")
+
+ # Create two transforms reading from the same file
+ t1 = SpuriousTextInjection(
+ text_key="text",
+ file_path=file_path,
+ location="end",
+ token_proportion=0.5,
+ seed=123,
+ )
+ t2 = SpuriousTextInjection(
+ text_key="text",
+ file_path=file_path,
+ location="end",
+ token_proportion=0.5,
+ seed=123,
+ )
+
+ sample1 = {"text": "base text", "label": 1}
+ sample2 = {"text": "base text", "label": 1}
+ outputs1 = [t1(sample1)["text"] for _ in range(5)]
+ outputs2 = [t2(sample2)["text"] for _ in range(5)]
+
+ assert outputs1 == outputs2, "Should produce identical results with same seed"
diff --git a/stable_pretraining/tests/unit/test_html_injection.py b/stable_pretraining/tests/unit/test_html_injection.py
new file mode 100644
index 00000000..1700fc3f
--- /dev/null
+++ b/stable_pretraining/tests/unit/test_html_injection.py
@@ -0,0 +1,82 @@
+import pytest
+from stable_pretraining.data.transforms import HTMLInjection
+
+
+@pytest.mark.unit
+def test_html_injection_deterministic_same_idx(tmp_path):
+ tag_path = tmp_path / "tags.txt"
+ tag_path.write_text("
\n") + + sample = {"text": "check reproducibility", "label": "lbl", "idx": 42} + + modifier1 = HTMLInjection( + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=777, + ) + modifier2 = HTMLInjection( + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=777, + ) + + out1 = modifier1(sample) + out2 = modifier2(sample) + + assert out1["text"] == out2["text"], ( + "Same base seed and idx must yield identical output" + ) diff --git a/stable_pretraining/tests/unit/test_item_injection.py b/stable_pretraining/tests/unit/test_item_injection.py new file mode 100644 index 00000000..c66c0fc4 --- /dev/null +++ b/stable_pretraining/tests/unit/test_item_injection.py @@ -0,0 +1,116 @@ +import pytest +from stable_pretraining.data.transforms import SpuriousTextInjection, AddSampleIdx + + +@pytest.mark.unit +def test_spurious_text_injection_deterministic_same_idx(tmp_path): + # Create dummy spurious tokens file + src_path = tmp_path / "spurious.txt" + src_path.write_text("RED\nGREEN\nBLUE\n") + + src_path2 = tmp_path / "spurious2.txt" + src_path2.write_text("RED\nGREEN\nBLUE\n") + + text = "deterministic spurious injection test" + transform1 = SpuriousTextInjection( + text_key="text", + file_path=str(src_path), + location="random", + token_proportion=0.5, + seed=123, + ) + + transform2 = SpuriousTextInjection( + text_key="text", + file_path=str(src_path2), + location="random", + token_proportion=0.5, + seed=123, + ) + + sample1 = {"text": text, "label": "A", "idx": 5} + sample2 = {"text": text, "label": "A", "idx": 5} + + out1 = transform1(sample1) + out2 = transform2(sample2) + + assert out1["text"] == out2["text"], "Same idx should produce identical injection" + assert out1["label"] == out2["label"] + + +@pytest.mark.unit +def test_spurious_text_injection_deterministic_different_idx(tmp_path): + src_path = tmp_path / "spurious.txt" + src_path.write_text("HELLO\nWORLD\n") + + text = "check for idx-dependent difference" + transform = SpuriousTextInjection( + text_key="text", + file_path=str(src_path), + location="random", + token_proportion=0.5, + seed=321, + ) + + sample1 = {"text": text, "label": "B", "idx": 1} + sample2 = {"text": text, "label": "B", "idx": 2} + + out1 = transform(sample1) + out2 = transform(sample2) + + assert out1["text"] != out2["text"], "Different idx should yield different outputs" + + +@pytest.mark.unit +def test_spurious_text_injection_reproducibility_across_runs(tmp_path): + src_path = tmp_path / "tokens.txt" + src_path.write_text("A\nB\nC\n") + + text = "reproducibility test for spurious injection" + sample = {"text": text, "label": "C", "idx": 42} + + transform1 = SpuriousTextInjection( + text_key="text", + file_path=str(src_path), + location="end", + token_proportion=0.25, + seed=999, + ) + transform2 = SpuriousTextInjection( + text_key="text", + file_path=str(src_path), + location="end", + token_proportion=0.25, + seed=999, + ) + + out1 = transform1(sample) + out2 = transform2(sample) + + assert out1["text"] == out2["text"], ( + "Same seed and idx should yield identical injection" + ) + + +@pytest.mark.unit +def test_spurious_text_injection_with_addsampleidx(tmp_path): + src_path = tmp_path / "spurious.txt" + src_path.write_text("NOISE\nTAG\n") + + add_idx = AddSampleIdx() + transform = SpuriousTextInjection( + text_key="text", + file_path=str(src_path), + location="beginning", + token_proportion=0.3, + seed=777, + ) + + sample = {"text": "verify deterministic pipeline", "label": "Y"} + sample = add_idx(sample) + out = transform(sample) + + assert "idx" in sample, "AddSampleIdx should add an 'idx' field" + assert any(t in out["text"] for t in ["NOISE", "TAG"]), ( + "Should inject spurious token" + ) diff --git a/stable_pretraining/tests/unit/test_transforms.py b/stable_pretraining/tests/unit/test_transforms.py index 90d277f5..fc098095 100644 --- a/stable_pretraining/tests/unit/test_transforms.py +++ b/stable_pretraining/tests/unit/test_transforms.py @@ -91,3 +91,94 @@ def test_transform_params_initialization(self): for t in transforms_to_test: assert t is not None + + # --------------------------- + # Spurious correlation tests + # --------------------------- + + def test_add_sample_idx_transform(self): + """Test that AddSampleIdx correctly increments indices.""" + transform = transforms.AddSampleIdx() + x1 = {"image": torch.zeros(3, 32, 32)} + x2 = {"image": torch.zeros(3, 32, 32)} + out1 = transform(x1) + out2 = transform(x2) + assert out1["idx"] == 0 + assert out2["idx"] == 1 + + def test_add_patch_transform(self): + """Test that AddPatch overlays a colored patch.""" + img = torch.zeros(3, 32, 32) + data = {"image": img.clone()} + transform = transforms.AddPatch( + patch_size=0.25, color=(1.0, 0.0, 0.0), position="top_left_corner" + ) + result = transform(data) + # Top-left corner should now contain red pixels + patch_area = result["image"][:, :8, :8] + assert torch.allclose(patch_area[0], torch.ones_like(patch_area[0]), atol=1e-3) + assert torch.allclose( + patch_area[1:], torch.zeros_like(patch_area[1:]), atol=1e-3 + ) + + def test_add_color_tint_transform(self): + """Test AddColorTint applies an additive tint.""" + img = torch.zeros(3, 16, 16) + data = {"image": img} + transform = transforms.AddColorTint(tint=(1.0, 0.5, 0.5), alpha=0.5) + result = transform(data) + # Image should not be all zeros anymore + assert torch.any(result["image"] > 0) + + def test_add_border_transform(self): + """Test AddBorder draws a colored border.""" + img = torch.zeros(3, 20, 20) + data = {"image": img} + transform = transforms.AddBorder(thickness=0.1, color=(0, 1, 0)) + result = transform(data) + # Corners should have green (0,1,0) + assert torch.allclose(result["image"][1, 0, 0], torch.tensor(1.0), atol=1e-3) + assert torch.allclose(result["image"][0, 0, 0], torch.tensor(0.0), atol=1e-3) + + def test_add_watermark_transform(self, tmp_path): + """Test AddWatermark overlays another image.""" + # Create a dummy watermark (white square) + wm_path = tmp_path / "wm.png" + from torchvision.utils import save_image + + save_image(torch.ones(3, 8, 8), wm_path) + data = {"image": torch.zeros(3, 32, 32)} + transform = transforms.AddWatermark( + str(wm_path), size=0.25, position="center", alpha=1.0 + ) + result = transform(data) + # There should be a bright region in the center + center = result["image"][:, 12:20, 12:20] + assert torch.mean(center) > 0.5 + + def test_class_conditional_injector(self): + """Test ClassConditionalInjector applies transform to correct labels only.""" + base_transform = transforms.AddPatch(color=(0, 1, 0)) + injector = transforms.ClassConditionalInjector( + transformation=base_transform, + target_labels=[1], + proportion=1.0, + total_samples=5, + seed=42, + ) + + # Prepare samples with idx + label + samples = [ + {"image": torch.zeros(3, 16, 16), "label": torch.tensor(label), "idx": idx} + for idx, label in enumerate([0, 1, 1, 0, 1]) + ] + + outputs = [injector(s) for s in samples] + + # Check that only samples with label of 1 were modified + for s_in, s_out in zip(samples, outputs): + mean_pixel = s_out["image"].mean().item() + if s_in["label"] == 1: + assert mean_pixel > 0 # patch added + else: + assert mean_pixel == 0 # unchanged diff --git a/stable_pretraining/tests/unit/test_write_random_dates.py b/stable_pretraining/tests/unit/test_write_random_dates.py new file mode 100644 index 00000000..9f53babf --- /dev/null +++ b/stable_pretraining/tests/unit/test_write_random_dates.py @@ -0,0 +1,80 @@ +import pytest +from pathlib import Path +from stable_pretraining.data.utils import write_random_dates + + +@pytest.mark.unit +def test_no_duplicates_with_replacement_false(tmp_path: Path): + """Ensure write_random_dates produces unique dates when with_replacement=False.""" + output_file = tmp_path / "dates.txt" + write_random_dates( + output_file, n=365, year_range=(2020, 2020), seed=123, with_replacement=False + ) + + dates = [line.strip() for line in open(output_file)] + assert len(dates) == len(set(dates)), "Duplicates found when with_replacement=False" + + +@pytest.mark.unit +def test_with_replacement_allows_duplicates(tmp_path: Path): + """Ensure write_random_dates allows duplicates when with_replacement=True.""" + output_file = tmp_path / "dates.txt" + write_random_dates( + output_file, n=1000, year_range=(2020, 2020), seed=42, with_replacement=True + ) + + dates = [line.strip() for line in open(output_file)] + # With replacement, duplicates should appear in large samples + assert len(set(dates)) < len(dates), ( + "No duplicates found when with_replacement=True" + ) + + +@pytest.mark.unit +def test_same_seed_produces_same_output(tmp_path: Path): + """Ensure deterministic output for same seed.""" + f1 = tmp_path / "dates1.txt" + f2 = tmp_path / "dates2.txt" + + write_random_dates( + f1, n=100, year_range=(1900, 1905), seed=42, with_replacement=True + ) + write_random_dates( + f2, n=100, year_range=(1900, 1905), seed=42, with_replacement=True + ) + + d1 = open(f1).read().splitlines() + d2 = open(f2).read().splitlines() + + assert d1 == d2, "Outputs differ for same seed" + + +@pytest.mark.unit +def test_different_seed_produces_different_output(tmp_path: Path): + """Ensure non-deterministic output for different seeds.""" + f1 = tmp_path / "dates1.txt" + f2 = tmp_path / "dates2.txt" + + write_random_dates( + f1, n=100, year_range=(1900, 1905), seed=42, with_replacement=True + ) + write_random_dates( + f2, n=100, year_range=(1900, 1905), seed=99, with_replacement=True + ) + + d1 = open(f1).read().splitlines() + d2 = open(f2).read().splitlines() + + assert d1 != d2, "Outputs identical for different seeds" + + +@pytest.mark.unit +def test_number_of_lines_written(tmp_path: Path): + """Ensure exactly n lines are written.""" + output_file = tmp_path / "dates.txt" + n = 50 + write_random_dates( + output_file, n=n, year_range=(2020, 2020), seed=0, with_replacement=False + ) + lines = open(output_file).read().splitlines() + assert len(lines) == n, f"Expected {n} lines, got {len(lines)}"