From 8347b7482ba4fede3f47be57110824da4ffdff18 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Fri, 3 Oct 2025 15:53:02 -0400 Subject: [PATCH 01/27] spurious correlation package --- stable_pretraining/_version.py | 2 +- .../data/spurious_corr/.DS_Store | Bin 0 -> 6148 bytes .../data/spurious_corr/__init__.py | 23 + .../data/spurious_corr/data/colors.txt | 52 +++ .../data/spurious_corr/data/countries.txt | 194 +++++++++ .../spurious_corr/data/double_exclamation.txt | 1 + .../data/spurious_corr/data/exclamation.txt | 1 + .../data/spurious_corr/data/html_tags.txt | 106 +++++ .../data/spurious_corr/data/random.txt | 4 + .../spurious_corr/data/two_hundred_dates.txt | 200 +++++++++ .../data/spurious_corr/generators.py | 117 ++++++ .../data/spurious_corr/modifiers.py | 393 ++++++++++++++++++ .../data/spurious_corr/sample_execution.py | 278 +++++++++++++ .../data/spurious_corr/setup.py | 9 + .../tests/test_date_generator.py | 59 +++ .../tests/test_fileitem_generator.py | 79 ++++ .../tests/test_html_injection.py | 203 +++++++++ .../tests/test_item_injection.py | 139 +++++++ .../spurious_corr/tests/test_transform.py | 97 +++++ .../data/spurious_corr/transform.py | 54 +++ .../data/spurious_corr/utils.py | 108 +++++ 21 files changed, 2118 insertions(+), 1 deletion(-) create mode 100644 stable_pretraining/data/spurious_corr/.DS_Store create mode 100644 stable_pretraining/data/spurious_corr/__init__.py create mode 100644 stable_pretraining/data/spurious_corr/data/colors.txt create mode 100644 stable_pretraining/data/spurious_corr/data/countries.txt create mode 100644 stable_pretraining/data/spurious_corr/data/double_exclamation.txt create mode 100644 stable_pretraining/data/spurious_corr/data/exclamation.txt create mode 100644 stable_pretraining/data/spurious_corr/data/html_tags.txt create mode 100644 stable_pretraining/data/spurious_corr/data/random.txt create mode 100644 stable_pretraining/data/spurious_corr/data/two_hundred_dates.txt create mode 100644 stable_pretraining/data/spurious_corr/generators.py create mode 100644 stable_pretraining/data/spurious_corr/modifiers.py create mode 100644 stable_pretraining/data/spurious_corr/sample_execution.py create mode 100644 stable_pretraining/data/spurious_corr/setup.py create mode 100644 stable_pretraining/data/spurious_corr/tests/test_date_generator.py create mode 100644 stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py create mode 100644 stable_pretraining/data/spurious_corr/tests/test_html_injection.py create mode 100644 stable_pretraining/data/spurious_corr/tests/test_item_injection.py create mode 100644 stable_pretraining/data/spurious_corr/tests/test_transform.py create mode 100644 stable_pretraining/data/spurious_corr/transform.py create mode 100644 stable_pretraining/data/spurious_corr/utils.py diff --git a/stable_pretraining/_version.py b/stable_pretraining/_version.py index 7de4f0ec..5d942ad4 100644 --- a/stable_pretraining/_version.py +++ b/stable_pretraining/_version.py @@ -1 +1 @@ -version = "0.1.3.dev0+g1505740ae.d20250925" +version = "0.1.dev363+g1699039ff.d20251003" diff --git a/stable_pretraining/data/spurious_corr/.DS_Store b/stable_pretraining/data/spurious_corr/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..eaed827776a129e5e41afb0e46c7c9817cdb7c34 GIT binary patch literal 6148 zcmeHKOHRWu5FM8wMYKqj*swv$2`X`eP?ZI14gme6B~nPLek8ifo;z>_&c_Pgj0cpu zVT%yVRO9D3@A>nj*fkNk;dXXO)F+}C$rzoWXbFDLc@Q0KVV#pca@x=xegj$_u&u!y zFb95{1N`lF<(ti^q~`a#yD6&aq^Krf@b++p9K1`Q_NiLZj;1t5K2XN}1gh6S710dr z4UPAC?jqow(gOXJ$d&Lb;9B;TU|#MyQ1%pN%GMG5vC^I2i8 zEhC7NOi@39Qbm8_lE?@m@3vDW$Qp8R{&sy zW(m~!&jM=_#Z+kRX`yzWX2h4##<$#Mux1%AJq-X2F;`pphkoQOq=G7iWf}z~e$;^v + + + + + + + + +

+

+

+

+
+
+

+
+
+
 
+ + + + + + + + + + + + + + + +
+ + + + + + + + +
+
+
  • +
    +
    +
    + + + + + +
    +
    + + + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + + + +
    + + + + + + + + + + + +
    +
    +
    +
    +
    + +
    + +
    diff --git a/stable_pretraining/data/spurious_corr/data/random.txt b/stable_pretraining/data/spurious_corr/data/random.txt new file mode 100644 index 00000000..8422d40f --- /dev/null +++ b/stable_pretraining/data/spurious_corr/data/random.txt @@ -0,0 +1,4 @@ +A +B +C +D diff --git a/stable_pretraining/data/spurious_corr/data/two_hundred_dates.txt b/stable_pretraining/data/spurious_corr/data/two_hundred_dates.txt new file mode 100644 index 00000000..853060bf --- /dev/null +++ b/stable_pretraining/data/spurious_corr/data/two_hundred_dates.txt @@ -0,0 +1,200 @@ +1975-02-20 +1975-04-19 +1976-11-02 +1976-11-21 +1976-12-10 +1976-12-30 +1977-05-16 +1977-07-21 +1977-10-17 +1977-10-27 +1977-10-31 +1977-12-04 +1978-05-23 +1979-04-20 +1979-07-29 +1979-08-30 +1979-10-09 +1979-10-25 +1979-11-21 +1980-04-08 +1980-05-11 +1980-06-30 +1980-09-26 +1981-02-17 +1981-03-12 +1981-03-18 +1981-05-09 +1981-08-01 +1982-03-12 +1982-03-13 +1982-04-13 +1982-09-27 +1982-11-05 +1982-11-21 +1982-12-07 +1983-01-26 +1983-06-03 +1983-06-07 +1983-09-14 +1983-09-21 +1983-10-26 +1983-11-06 +1984-01-23 +1984-06-07 +1984-08-19 +1984-10-25 +1984-11-21 +1984-11-30 +1985-02-20 +1985-07-26 +1985-10-23 +1986-01-18 +1986-04-01 +1986-08-07 +1986-11-08 +1986-11-16 +1986-12-24 +1987-02-27 +1987-10-16 +1988-01-21 +1988-05-03 +1989-03-11 +1989-08-12 +1989-08-27 +1989-09-27 +1990-02-09 +1990-08-14 +1990-12-24 +1991-01-08 +1991-02-05 +1991-10-11 +1991-11-29 +1992-02-11 +1992-02-18 +1992-06-30 +1992-08-07 +1992-09-28 +1992-11-24 +1993-06-16 +1994-03-21 +1994-06-13 +1994-06-27 +1994-09-26 +1994-10-22 +1995-02-11 +1995-06-12 +1995-06-21 +1995-07-02 +1995-07-17 +1995-10-18 +1995-10-27 +1996-07-10 +1996-07-29 +1998-01-07 +1998-02-18 +1998-03-06 +1998-06-24 +1998-08-06 +1998-09-15 +1998-12-21 +1999-03-17 +1999-05-30 +1999-08-01 +2000-01-07 +2000-03-13 +2000-04-30 +2000-06-15 +2000-07-29 +2000-09-17 +2000-12-13 +2000-12-22 +2000-12-30 +2001-01-29 +2001-03-04 +2001-08-04 +2002-04-19 +2002-06-07 +2002-08-24 +2002-09-25 +2003-01-11 +2003-05-02 +2004-01-11 +2004-05-02 +2004-05-31 +2004-11-11 +2004-12-31 +2005-02-03 +2005-02-20 +2005-04-10 +2005-07-21 +2005-10-06 +2006-05-25 +2006-07-22 +2006-09-21 +2006-12-29 +2007-04-06 +2007-04-25 +2007-08-26 +2007-09-03 +2008-01-08 +2008-06-01 +2008-06-30 +2008-10-17 +2009-02-28 +2009-10-10 +2010-02-01 +2010-03-26 +2010-06-18 +2011-01-16 +2011-02-24 +2011-03-15 +2011-04-06 +2011-07-27 +2011-10-20 +2011-12-20 +2012-09-10 +2012-10-04 +2013-04-04 +2013-07-15 +2013-11-24 +2014-03-12 +2014-03-19 +2014-11-19 +2015-08-05 +2016-01-26 +2016-01-29 +2016-03-05 +2016-06-05 +2016-12-26 +2017-04-18 +2017-05-21 +2017-09-01 +2017-09-04 +2018-02-24 +2018-03-13 +2018-04-21 +2018-07-20 +2018-10-13 +2019-06-05 +2019-07-14 +2019-08-22 +2019-10-30 +2020-05-30 +2020-08-23 +2020-09-06 +2020-11-27 +2021-06-10 +2021-07-04 +2021-09-15 +2021-10-16 +2021-11-04 +2022-06-28 +2022-08-09 +2022-08-16 +2023-08-29 +2024-03-23 +2024-07-03 +2024-08-06 +2024-12-28 +2025-11-14 diff --git a/stable_pretraining/data/spurious_corr/generators.py b/stable_pretraining/data/spurious_corr/generators.py new file mode 100644 index 00000000..75776796 --- /dev/null +++ b/stable_pretraining/data/spurious_corr/generators.py @@ -0,0 +1,117 @@ +"""generators.py. + +This module provides generator functions for creating spurious text injections. +These functions can be used directly or integrated with the ItemInjection modifier. +""" + +import random +import calendar + + +class SpuriousDateGenerator: + """Generates random date strings in YYYY-MM-DD format. + + Can be configured to allow or disallow duplicates. + """ + + def __init__(self, year_range=(1100, 2600), seed=None, with_replacement=False): + """Initialize the generator. + + Args: + year_range (tuple): A (start_year, end_year) tuple. + seed (int, optional): Seed for reproducibility. + with_replacement (bool): Whether to allow duplicates. + """ + self.rng = random.Random(seed) + self.with_replacement = with_replacement + self.generated = set() + self.possible_dates = self._generate_all_valid_dates(year_range) + self.total_possible = len(self.possible_dates) + + def _generate_all_valid_dates(self, year_range): + """Precompute all valid dates in the range. + + Args: + year_range (tuple): A (start_year, end_year) tuple. + + Returns: + list[str]: List of all valid dates in the range. + """ + start_year, end_year = year_range + dates = [] + for year in range(start_year, end_year + 1): + for month in range(1, 13): + _, max_day = calendar.monthrange(year, month) + for day in range(1, max_day + 1): + date_str = f"{year}-{month:02d}-{day:02d}" + dates.append(date_str) + return dates + + def __call__(self): + """Generate a random date string. + + Returns: + str: A random date string. + + Raises: + RuntimeError: If all unique dates have been generated (when with_replacement is False). + """ + if self.with_replacement: + return self.rng.choice(self.possible_dates) + + if len(self.generated) >= self.total_possible: + raise RuntimeError("All unique dates have been generated.") + + while True: + date = self.rng.choice(self.possible_dates) + if date not in self.generated: + self.generated.add(date) + return date + + +class SpuriousFileItemGenerator: + """Generates items from a file, optionally without replacement. + + Each non-empty line in the file is considered a distinct item. + """ + + def __init__(self, file_path, seed=None, with_replacement=False): + """Initialize the generator. + + Args: + file_path (str): Path to the file with one item per line. + seed (int, optional): Seed for reproducibility. + with_replacement (bool): Whether to allow duplicates. + """ + self.rng = random.Random(seed) + self.with_replacement = with_replacement + self.generated = set() + + with open(file_path, "r", encoding="utf-8") as f: + self.items = [line.strip() for line in f if line.strip()] + + if not self.items: + raise ValueError("File is empty or contains only blank lines.") + + self.total_possible = len(self.items) + + def __call__(self): + """Generate a random item from the file. + + Returns: + str: A random item. + + Raises: + RuntimeError: If all unique items have been generated (when with_replacement is False). + """ + if self.with_replacement: + return self.rng.choice(self.items) + + if len(self.generated) >= self.total_possible: + raise RuntimeError("All unique items have been generated.") + + while True: + item = self.rng.choice(self.items) + if item not in self.generated: + self.generated.add(item) + return item diff --git a/stable_pretraining/data/spurious_corr/modifiers.py b/stable_pretraining/data/spurious_corr/modifiers.py new file mode 100644 index 00000000..1a1cd903 --- /dev/null +++ b/stable_pretraining/data/spurious_corr/modifiers.py @@ -0,0 +1,393 @@ +"""modifiers.py. + +This module defines the base Modifier class, as well as subclasses for injecting items +(ItemInjection) and HTML tags (HTMLInjection) into text, as well as composing multiple +modifiers (CompositeModifier). +""" + +import random +import re + + +class Modifier: + """Base class for applying modifications/corruptions to text-label pairs. + + Subclasses must implement the __call__ method to define specific transformations. + + Example: + class MyModifier(Modifier): + def __call__(self, text: str, label: Any) -> tuple[str, Any]: + # custom transformation here + return transformed_text, transformed_label + """ + + def __call__(self, text: str, label): + """Apply the transformation to a single text-label pair. + + Args: + text (str): The input text to transform. + label: The associated label. + + Returns: + tuple: (transformed_text, transformed_label) + """ + raise NotImplementedError("Subclasses must implement __call__") + + +class CompositeModifier: + """CompositeModifier chains multiple Modifier instances together. + + Each modifier from the list is applied sequentially to the text. This enables + the combination of various transformations or injections into one composite operation. + """ + + def __init__(self, modifiers: list): + """Initialize a CompositeModifier instance. + + Args: + modifiers (list): A list of modifier instances (subclasses of Modifier) + to be applied sequentially. + """ + self.modifiers = modifiers + + def __call__(self, text: str, label): + """Apply all modifiers in sequence to the given (text, label). + + Args: + text (str): The input text. + label: The associated label. + + Returns: + tuple: The modified (text, label) pair after all transformations. + """ + for modifier in self.modifiers: + text, label = modifier(text, label) + return text, label + + +class ItemInjection(Modifier): + """A Modifier that injects items into text. + + This class supports creation via three different approaches: + - from_list: Using a predefined list of injection items. + - from_file: Reading injection items from a file. + - from_function: Using a custom function to generate injections. + """ + + def __init__( + self, + injection_source, + location: str = "random", + token_proportion: float = 0.1, + seed=None, + _rng=None, + ): + """Initialize an ItemInjection instance. + + Args: + injection_source (callable): A function that returns an injection token. + location (str): Where to inject the token ("beginning", "random", "end"). + token_proportion (float): Proportion of tokens in the text to be affected. + seed (int, optional): Seed for reproducibility. + """ + assert callable(injection_source), "injection_source must be callable" + self.injection_source = injection_source + self.location = location + self.token_proportion = token_proportion + self.rng = _rng or random.Random(seed) + + assert 0 <= token_proportion <= 1, "token_proportion must be between 0 and 1" + assert location in {"beginning", "random", "end"}, ( + "location must be 'beginning', 'random', or 'end'" + ) + + def __call__(self, text: str, label): + """Inject tokens into the text at specified locations. + + Args: + text (str): The input text to modify. + label: The original label (unchanged). + + Returns: + tuple: The modified text and the original label. + """ + words = text.split() + num_tokens = len(words) + + # Ensure at least one token is injected + num_to_inject = max(1, int(num_tokens * self.token_proportion)) + + injections = [self.injection_source() for _ in range(num_to_inject)] + + if self.location == "beginning": + words = injections + words + elif self.location == "end": + words = words + injections + elif self.location == "random": + for injection in injections: + pos = self.rng.randint(0, len(words)) + words.insert(pos, injection) + + return " ".join(words), label # return modified text and unchanged label + + @classmethod + def from_list( + cls, + items: list, + location: str = "random", + token_proportion: float = 0.1, + seed=None, + ): + """Create an ItemInjection instance using a predefined list of tokens. + + Args: + items (list): List of token strings to choose from. + location (str): Where to inject tokens ("beginning", "random", "end"). + token_proportion (float): Proportion of text tokens to be affected. + seed (int, optional): Seed for reproducibility. + + Returns: + ItemInjection: Configured instance. + """ + rng = random.Random(seed) + + def injection_source(): + return rng.choice(items) + + return cls( + injection_source, + location=location, + token_proportion=token_proportion, + seed=seed, + _rng=rng, + ) + + @classmethod + def from_file( + cls, + file_path: str, + location: str = "random", + token_proportion: float = 0.1, + seed=None, + ): + """Create an ItemInjection instance using tokens read from a file. + + Each non-empty line becomes a potential injection item. + + Args: + file_path (str): Path to the file with one token per line. + location (str): Where to inject tokens. + token_proportion (float): Proportion of tokens to inject. + seed (int, optional): Seed for reproducibility. + + Returns: + ItemInjection: Configured instance. + """ + with open(file_path, "r", encoding="utf-8") as file: + items = [line.strip() for line in file if line.strip()] + + rng = random.Random(seed) + + def injection_source(): + return rng.choice(items) + + return cls( + injection_source, + location=location, + token_proportion=token_proportion, + _rng=rng, + ) + + @classmethod + def from_function( + cls, + injection_func, + location: str = "random", + token_proportion: float = 0.1, + seed=None, + ): + """Create an ItemInjection instance using a custom function to generate injections. + + Args: + injection_func (callable): Function that returns a new injection token each time. + location (str): Where to inject tokens. + token_proportion (float): Proportion of text to inject into. + seed (int, optional): Seed for reproducibility (used only for insertion position). + + Returns: + ItemInjection: Configured instance. + """ + assert callable(injection_func), "injection_func must be callable" + return cls( + injection_func, + location=location, + token_proportion=token_proportion, + seed=seed, + ) + + +class HTMLInjection(Modifier): + """A Modifier that injects html into text. + + This class supports creation via two different approaches: + - from_list: Using a predefined list of injection items. + - from_file: Reading injection items from a file. + """ + + def __init__( + self, + file_path: str, + location: str = "random", + level: int = None, + token_proportion: float = None, + seed=None, + ): + with open(file_path, "r", encoding="utf-8") as f: + self.tags = [line.strip() for line in f if line.strip()] + self.location = location + self.level = level + self.token_proportion = token_proportion + self.rng = random.Random(seed) + + if token_proportion is not None: + assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" + + @classmethod + def from_file( + cls, + file_path: str, + location: str = "random", + level: int = None, + token_proportion: float = None, + seed=None, + ): + return cls( + file_path, + location=location, + level=level, + token_proportion=token_proportion, + seed=seed, + ) + + @classmethod + def from_list( + cls, + tags: list, + location: str = "random", + level: int = None, + token_proportion: float = None, + seed=None, + ): + instance = cls.__new__(cls) + instance.tags = tags + instance.location = location + instance.level = level + instance.token_proportion = token_proportion + instance.rng = random.Random(seed) + + if token_proportion is not None: + assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" + + return instance + + def _choose_tag(self): + """Randomly choose a tag from the loaded list. + + Returns: + tuple: (opening_tag, closing_tag or None) + """ + line = self.rng.choice(self.tags) + parts = line.split() + if len(parts) >= 2: + return parts[0], parts[1] + else: + return parts[0], None + + def _inject_into_tokens(self, tokens, location): + tokens = tokens[:] + n = len(tokens) + + if self.token_proportion is None: + opening, closing = self._choose_tag() + return self._inject_with_tags(tokens, opening, closing, location) + + # Otherwise, inject up to token_proportion of total tokens + num_insertions = max(1, int(n * self.token_proportion)) + for _ in range(num_insertions): + opening, closing = self._choose_tag() + tokens = self._inject_with_tags(tokens, opening, closing, location) + return tokens + + def _inject_with_tags(self, tokens, opening, closing, location): + if location == "beginning": + new_tokens = [opening] + tokens + if closing: + pos = self.rng.randint(1, len(new_tokens)) + new_tokens.insert(pos, closing) + return new_tokens + + elif location == "end": + new_tokens = tokens[:] + pos = self.rng.randint(0, len(new_tokens)) + new_tokens.insert(pos, opening) + if closing: + new_tokens.append(closing) + return new_tokens + + elif location == "random": + new_tokens = tokens[:] + pos_open = self.rng.randint(0, len(new_tokens)) + new_tokens.insert(pos_open, opening) + if closing: + pos_close = self.rng.randint(pos_open + 1, len(new_tokens)) + new_tokens.insert(pos_close, closing) + return new_tokens + + return tokens + + def _inject(self, text, location): + tokens = text.split() + new_tokens = self._inject_into_tokens(tokens, location) + return " ".join(new_tokens) + + def _find_level_span(self, text, level): + """Find the first span inside the desired HTML nesting level. + + Args: + text (str): Input HTML text. + level (int): Desired nesting level. + + Returns: + tuple or None: (start, end) of the content region, or None if not found. + """ + tag_regex = re.compile(r"]*>") + stack = [] + for match in tag_regex.finditer(text): + tag_str = match.group(0) + tag_name = match.group(1) + if not tag_str.startswith(" \n \n \n") + + text = "this is a test sentence with eight tokens" + token_count = len(text.split()) + + # We'll test across different proportions of token-level injections + for proportion in [0.1, 0.25, 0.5, 0.75, 1.0]: + modifier = HTMLInjection.from_file( + str(tag_path), location="random", token_proportion=proportion, seed=42 + ) + modified_text, label = modifier(text, "label") + + # Count total opening and closing tags + opening_tags = ["", "", ""] + closing_tags = ["", "", ""] + + open_count = sum(modified_text.count(tag) for tag in opening_tags) + close_count = sum(modified_text.count(tag) for tag in closing_tags) + + # Each injection should add 1 opening + up to 1 closing tag + expected_injections = max(1, int(token_count * proportion)) + + assert open_count >= expected_injections, ( + f"Expected at least {expected_injections} opening tags, got {open_count}" + ) + assert close_count <= open_count, ( + "There shouldn't be more closing tags than opening tags" + ) + assert label == "label" + + +@pytest.mark.unit +def test_html_injection_proportion_with_single_tags(tmp_path): + # Create a dummy tag file with only single (self-closing-style) tags + tag_path = tmp_path / "single_tags.txt" + tag_path.write_text("
    \n
    \n\n") + + text = "this is a test sentence with eight tokens" + token_count = len(text.split()) + + for proportion in [0.1, 0.25, 0.5, 0.75, 1.0]: + modifier = HTMLInjection.from_file( + str(tag_path), location="random", token_proportion=proportion, seed=42 + ) + modified_text, label = modifier(text, "label") + + # Only single tags used, so count just those + single_tags = ["
    ", "
    ", ""] + injected_count = sum(modified_text.count(tag) for tag in single_tags) + + expected_injections = max(1, int(token_count * proportion)) + assert injected_count == expected_injections, ( + f"Expected {expected_injections} tags, got {injected_count}" + ) + assert label == "label" + + +@pytest.mark.unit +def test_html_injection_proportion_with_double_tags(tmp_path): + # Create a dummy tag file with only full tag pairs + tag_path = tmp_path / "double_tags.txt" + tag_path.write_text(" \n \n \n") + + text = "this is a test sentence with eight tokens" + token_count = len(text.split()) + + for proportion in [0.1, 0.25, 0.5, 0.75, 1.0]: + modifier = HTMLInjection.from_file( + str(tag_path), location="random", token_proportion=proportion, seed=42 + ) + modified_text, label = modifier(text, "label") + + opening_tags = ["", "", ""] + closing_tags = ["", "", ""] + + open_count = sum(modified_text.count(tag) for tag in opening_tags) + close_count = sum(modified_text.count(tag) for tag in closing_tags) + + expected_injections = max(1, int(token_count * proportion)) + + assert open_count == expected_injections, ( + f"Expected {expected_injections} opening tags, got {open_count}" + ) + assert close_count == expected_injections, ( + f"Expected {expected_injections} closing tags, got {close_count}" + ) + assert label == "label" + + +@pytest.mark.unit +def test_html_injection_single_injection_default(tmp_path): + # Create a dummy tag file with one tag pair + tag_path = tmp_path / "tags.txt" + tag_path.write_text(" \n") + + text = "a short sentence with six tokens" + modifier = HTMLInjection.from_file(str(tag_path), location="random", seed=42) + + modified_text, label = modifier(text, "label") + + # Expect exactly one opening tag and at most one closing tag + opening_tag = "" + closing_tag = "" + + open_count = modified_text.count(opening_tag) + close_count = modified_text.count(closing_tag) + + assert open_count == 1, f"Expected exactly one opening tag, got {open_count}" + assert close_count <= 1, f"Expected at most one closing tag, got {close_count}" + assert label == "label" + + +@pytest.mark.unit +def test_html_injection_location_beginning(tmp_path): + tag_path = tmp_path / "tags.txt" + tag_path.write_text(" \n") + text = "sample sentence" + + modifier = HTMLInjection.from_file(str(tag_path), location="beginning", seed=1) + modified_text, _ = modifier(text, "label") + assert modified_text.startswith(""), "Opening tag should be at the beginning" + + +@pytest.mark.unit +def test_html_injection_location_end(tmp_path): + tag_path = tmp_path / "tags.txt" + tag_path.write_text(" \n") + text = "another sample" + + modifier = HTMLInjection.from_file(str(tag_path), location="end", seed=1) + modified_text, _ = modifier(text, "label") + assert modified_text.endswith("") or "" in modified_text, ( + "Tag should be appended at end" + ) + + +@pytest.mark.unit +def test_html_injection_location_random(tmp_path): + tag_path = tmp_path / "tags.txt" + tag_path.write_text(" \n") + text = "tokens in various spots" + + modifier = HTMLInjection.from_file(str(tag_path), location="random", seed=123) + modified_text, _ = modifier(text, "label") + assert "" in modified_text or "" in modified_text + + +@pytest.mark.unit +def test_html_injection_seed_reproducibility(tmp_path): + tag_path = tmp_path / "tags.txt" + tag_path.write_text(" \n") + + text = "reproducibility is key" + mod1 = HTMLInjection.from_file( + str(tag_path), location="random", token_proportion=0.5, seed=42 + ) + mod2 = HTMLInjection.from_file( + str(tag_path), location="random", token_proportion=0.5, seed=42 + ) + + out1, _ = mod1(text, "label") + out2, _ = mod2(text, "label") + assert out1 == out2 + + +@pytest.mark.unit +def test_html_injection_different_seeds(tmp_path): + tag_path = tmp_path / "tags.txt" + tag_path.write_text(" \n") + text = "inject differently based on seed" + + mod1 = HTMLInjection.from_file( + str(tag_path), location="random", token_proportion=0.5, seed=1 + ) + mod2 = HTMLInjection.from_file( + str(tag_path), location="random", token_proportion=0.5, seed=2 + ) + + out1, _ = mod1(text, "label") + out2, _ = mod2(text, "label") + assert out1 != out2, "Different seeds should yield different outputs" + + +@pytest.mark.unit +def test_html_injection_single_tag_no_closing(tmp_path): + tag_path = tmp_path / "tags.txt" + tag_path.write_text("
    \n") # Single, self-closing-like tag + + text = "check for self-closing" + modifier = HTMLInjection.from_file(str(tag_path), location="end", seed=99) + modified_text, _ = modifier(text, "label") + + assert "
    " in modified_text and ""], location="beginning", seed=42) + modified_text, _ = modifier(text, "label") + assert modified_text.startswith(""), "Injection should be at the beginning" + + +@pytest.mark.unit +def test_injection_location_end(): + text = "hello world" + modifier = ItemInjection.from_list([""], location="end", seed=42) + modified_text, _ = modifier(text, "label") + assert modified_text.endswith(""), "Injection should be at the end" + + +@pytest.mark.download +def test_seed_reproducibility(): + dataset_name = "imdb" + data = llm_research.data.from_name(dataset_name) + train_dataset, _ = data["train"], data["test"] + with open("spurious_corr/data/countries.txt", "r", encoding="utf-8") as f: + country_list = [line.strip() for line in f if line.strip()] + + for i, example in enumerate(train_dataset): + if i >= 1000: + break + text = example["text"] + label = example["labels"] + mod1 = ItemInjection.from_list( + country_list, token_proportion=0.5, location="random", seed=123 + ) + mod2 = ItemInjection.from_list( + country_list, token_proportion=0.5, location="random", seed=123 + ) + mod3 = ItemInjection.from_file( + "spurious_corr/data/countries.txt", + token_proportion=0.5, + location="random", + seed=123, + ) + mod4 = ItemInjection.from_file( + "spurious_corr/data/countries.txt", + token_proportion=0.5, + location="random", + seed=123, + ) + + text1, label1 = mod1(text, label) + text2, label2 = mod2(text, label) + text3, label3 = mod3(text, label) + text4, label4 = mod4(text, label) + + assert text1 == text2 == text3 == text4 + assert label1 == label2 == label3 == label4 + + date_generator_1 = SpuriousDateGenerator(seed=541, with_replacement=False) + date_generator_2 = SpuriousDateGenerator(seed=541, with_replacement=False) + date_generator_3 = SpuriousDateGenerator(seed=541, with_replacement=False) + + for i, example in enumerate(train_dataset): + if i >= 1000: + break + text = example["text"] + label = example["labels"] + + mod1 = ItemInjection.from_function( + date_generator_1, token_proportion=0.45, location="random", seed=541 + ) + mod2 = ItemInjection.from_function( + date_generator_2, token_proportion=0.45, location="random", seed=541 + ) + mod3 = ItemInjection.from_function( + date_generator_3, token_proportion=0.45, location="random", seed=541 + ) + + text1, label1 = mod1(text, label) + text2, label2 = mod2(text, label) + text3, label3 = mod3(text, label) + + assert text1 == text2 == text3 + assert label1 == label2 == label3 + + +@pytest.mark.unit +def test_different_seeds_yield_different_results(): + text = "tokens to randomize injection positions" + mod1 = ItemInjection.from_list( + [""], token_proportion=0.5, location="random", seed=1 + ) + mod2 = ItemInjection.from_list( + [""], token_proportion=0.5, location="random", seed=2 + ) + + text1, _ = mod1(text, "label") + text2, _ = mod2(text, "label") + + assert text1 != text2, "Different seeds should yield different injection positions" diff --git a/stable_pretraining/data/spurious_corr/tests/test_transform.py b/stable_pretraining/data/spurious_corr/tests/test_transform.py new file mode 100644 index 00000000..5384d972 --- /dev/null +++ b/stable_pretraining/data/spurious_corr/tests/test_transform.py @@ -0,0 +1,97 @@ +import pytest +from spurious_corr.generators import SpuriousDateGenerator +from spurious_corr.modifiers import ItemInjection +from spurious_corr.transform import spurious_transform +import llm_research.data + + +@pytest.mark.download +def test_spurious_transform_proportion_multiple(): + dataset_name = "imdb" + data = llm_research.data.from_name(dataset_name) + train_dataset = data["train"].select(range(200)) + + label_to_modify = 1 + + with open("spurious_corr/data/countries.txt", "r", encoding="utf-8") as f: + country_list = [line.strip() for line in f if line.strip()] + + modifier = ItemInjection.from_list(country_list, token_proportion=0.5, seed=23) + + originals = [ex for ex in train_dataset] + + for text_proportion in [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]: + transformed = spurious_transform( + label_to_modify=label_to_modify, + dataset=train_dataset, + modifier=modifier, + text_proportion=text_proportion, + seed=42, + ) + + # Count modified examples (compare original vs transformed) + modified_count = sum( + 1 + for orig, mod in zip(originals, transformed) + if orig["labels"] == label_to_modify and orig["text"] != mod["text"] + ) + + total_to_modify = sum(1 for ex in originals if ex["labels"] == label_to_modify) + expected = round(total_to_modify * text_proportion) + + print( + f"[text_proportion={text_proportion}] Modified: {modified_count} / Expected: {expected}" + ) + assert modified_count == expected, ( + f"Expected {expected}, but got {modified_count} at proportion {text_proportion}" + ) + + +@pytest.mark.download +def test_spurious_transform_reproducible(): + dataset_name = "imdb" + data = llm_research.data.from_name(dataset_name) + train_dataset = data["train"].select(range(200)) + + date_generator_1 = SpuriousDateGenerator(seed=19, with_replacement=False) + modifier_1 = ItemInjection.from_function( + date_generator_1, token_proportion=0.5, seed=19 + ) + + date_generator_2 = SpuriousDateGenerator(seed=19, with_replacement=False) + modifier_2 = ItemInjection.from_function( + date_generator_2, token_proportion=0.5, seed=19 + ) + + transformed1 = spurious_transform(0, train_dataset, modifier_1, 0.3, seed=19) + transformed2 = spurious_transform(0, train_dataset, modifier_2, 0.3, seed=19) + + texts1 = [ex["text"] for ex in transformed1] + texts2 = [ex["text"] for ex in transformed2] + + assert texts1 == texts2, "Expected reproducible output with same seed" + + +@pytest.mark.download +def test_spurious_transform_different_seeds(): + dataset_name = "imdb" + data = llm_research.data.from_name(dataset_name) + train_dataset = data["train"].select(range(200)) + + date_generator_1 = SpuriousDateGenerator(seed=19, with_replacement=False) + modifier_1 = ItemInjection.from_function( + date_generator_1, token_proportion=0.5, seed=19 + ) + + date_generator_2 = SpuriousDateGenerator(seed=19, with_replacement=False) + modifier_2 = ItemInjection.from_function( + date_generator_2, token_proportion=0.5, seed=19 + ) + + transformed1 = spurious_transform(0, train_dataset, modifier_1, 0.3, seed=19) + transformed2 = spurious_transform(0, train_dataset, modifier_2, 0.3, seed=20) + + texts1 = [ex["text"] for ex in transformed1] + texts2 = [ex["text"] for ex in transformed2] + + assert texts1 != texts2, "Expected different outputs with different seeds" diff --git a/stable_pretraining/data/spurious_corr/transform.py b/stable_pretraining/data/spurious_corr/transform.py new file mode 100644 index 00000000..421bf0e2 --- /dev/null +++ b/stable_pretraining/data/spurious_corr/transform.py @@ -0,0 +1,54 @@ +"""transform.py. + +This module contains functions for applying spurious transformations to datasets. +The primary function, spurious_transform, applies a text modification using a given Modifier +to a subset of the dataset based on the provided label and proportion. +""" + +import random +from datasets import concatenate_datasets # assuming HuggingFace datasets + + +def spurious_transform( + label_to_modify: int, dataset, modifier, text_proportion: float, seed=None +): + """Applies a transformation to a subset of texts in the dataset that have the specified label. + + Args: + label_to_modify (int): The label of the text to modify. + dataset: The dataset containing the text data. + modifier: An instance of a Modifier subclass that modifies (text, label). + text_proportion (float): Proportion of texts to transform using the modifier (between 0 and 1). + seed (int, optional): Seed for random sampling reproducibility. + + Returns: + Dataset: A new dataset with the transformations applied to examples with the given label. + """ + dataset_to_modify = dataset.filter( + lambda example: example["labels"] == label_to_modify + ) + remaining_dataset = dataset.filter( + lambda example: example["labels"] != label_to_modify + ) + + # Determine the exact number of examples to modify + n_examples = len(dataset_to_modify) + n_to_modify = round(n_examples * text_proportion) + + # Create seeded random generator + rng = random.Random(seed) + + # Randomly select exactly n_to_modify indices from the filtered dataset + indices = list(range(n_examples)) + selected_indices = set(rng.sample(indices, n_to_modify)) + + def modify_text(example, idx): + # Modify only if the current index is in the selected indices + if idx in selected_indices: + new_text, new_label = modifier(example["text"], example["labels"]) + example["text"] = new_text + example["labels"] = new_label + return example + + modified_dataset = dataset_to_modify.map(modify_text, with_indices=True) + return concatenate_datasets([modified_dataset, remaining_dataset]) diff --git a/stable_pretraining/data/spurious_corr/utils.py b/stable_pretraining/data/spurious_corr/utils.py new file mode 100644 index 00000000..8d5da09d --- /dev/null +++ b/stable_pretraining/data/spurious_corr/utils.py @@ -0,0 +1,108 @@ +"""utils.py. + +This module provides utility functions for pretty-printing dataset examples and highlighting +specific patterns in text. These functions are useful for debugging and visualizing the modifications +applied to the dataset. +""" + +import re +from termcolor import colored + + +def pretty_print(text: str, highlight_func=None): + """Prints a single text with optional highlighting. + + Args: + text (str): The text to print. + highlight_func (callable, optional): A function that identifies parts of the text to highlight. + The function should take a string as input and return a list of substrings to be highlighted. + """ + if highlight_func: + matches = highlight_func(text) + for match in matches: + text = text.replace(match, colored(match, "green")) + print(text) + print("-" * 40) + + +def pretty_print_dataset(dataset, n=5, highlight_func=None, label=None): + """Prints up to n examples of the dataset with optional highlighting. + + If a label is provided, only examples with that label are printed. + + Args: + dataset: A dataset containing text and labels. + n (int): Maximum number of examples to print (default is 5). + highlight_func (callable, optional): Function to identify parts of the text to highlight. + label (int, optional): If provided, only examples with this label are printed. + """ + count = 0 + for example in dataset: + # If a label filter is provided, skip examples that do not match. + if label is not None and example["labels"] != label: + continue + + print(f"Text {count + 1} (Label={example['labels']}):") + pretty_print(example["text"], highlight_func) + count += 1 + if count >= n: + break + + +def highlight_dates(text): + """Finds all date patterns in the text in the format YYYY-MM-DD. + + Args: + text (str): The text to search. + + Returns: + list: A list of date strings found in the text. + """ + return re.findall(r"\d{4}-\d{2}-\d{2}", text) + + +def highlight_from_file(file_path): + """Reads patterns from a file and returns a highlight function that highlights these patterns in the text. + + Args: + file_path (str): Path to the file containing patterns. + + Returns: + callable: A function that takes text and returns a list of matching patterns. + """ + with open(file_path, "r", encoding="utf-8") as file: + patterns = [line.strip() for line in file if line.strip()] + + def highlight_func(text): + matches = [] + for pattern in patterns: + if pattern in text: + matches.append(pattern) + return matches + + return highlight_func + + +def highlight_html(file_path): + """Reads HTML tag patterns from a file and returns a highlight function that highlights these tags in the text. + + Args: + file_path (str): Path to the file containing HTML tag patterns. + + Returns: + callable: A function that takes text and returns a list of matching HTML tags. + """ + with open(file_path, "r", encoding="utf-8") as file: + patterns = [line.strip() for line in file if line.strip()] + tags = [] + for line in patterns: + tags.extend(line.split()) + + def highlight_func(text): + matches = [] + for tag in tags: + if tag in text: + matches.append(tag) + return matches + + return highlight_func From ed30672df5d4746d626c7c60f315bd8f60b544d1 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Fri, 3 Oct 2025 17:29:13 -0400 Subject: [PATCH 02/27] fixed the unit tests --- .../tests/test_date_generator.py | 2 +- .../tests/test_fileitem_generator.py | 2 +- .../tests/test_html_injection.py | 2 +- .../tests/test_item_injection.py | 72 +------------- .../spurious_corr/tests/test_transform.py | 97 ------------------- 5 files changed, 4 insertions(+), 171 deletions(-) delete mode 100644 stable_pretraining/data/spurious_corr/tests/test_transform.py diff --git a/stable_pretraining/data/spurious_corr/tests/test_date_generator.py b/stable_pretraining/data/spurious_corr/tests/test_date_generator.py index 341a0118..3dd2045c 100644 --- a/stable_pretraining/data/spurious_corr/tests/test_date_generator.py +++ b/stable_pretraining/data/spurious_corr/tests/test_date_generator.py @@ -1,5 +1,5 @@ import pytest -from spurious_corr.generators import SpuriousDateGenerator +from stable_pretraining.data.spurious_corr.generators import SpuriousDateGenerator @pytest.mark.unit diff --git a/stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py b/stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py index 85f5d293..7b8fc91c 100644 --- a/stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py +++ b/stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py @@ -1,7 +1,7 @@ import pytest import tempfile import os -from spurious_corr.generators import SpuriousFileItemGenerator +from stable_pretraining.data.spurious_corr.generators import SpuriousFileItemGenerator # Utility to create a temp file with test content diff --git a/stable_pretraining/data/spurious_corr/tests/test_html_injection.py b/stable_pretraining/data/spurious_corr/tests/test_html_injection.py index c7ac075a..a9df883e 100644 --- a/stable_pretraining/data/spurious_corr/tests/test_html_injection.py +++ b/stable_pretraining/data/spurious_corr/tests/test_html_injection.py @@ -1,5 +1,5 @@ import pytest -from spurious_corr.modifiers import HTMLInjection +from stable_pretraining.data.spurious_corr.modifiers import HTMLInjection @pytest.mark.unit diff --git a/stable_pretraining/data/spurious_corr/tests/test_item_injection.py b/stable_pretraining/data/spurious_corr/tests/test_item_injection.py index a5d64373..d1f6979b 100644 --- a/stable_pretraining/data/spurious_corr/tests/test_item_injection.py +++ b/stable_pretraining/data/spurious_corr/tests/test_item_injection.py @@ -1,7 +1,5 @@ import pytest -from spurious_corr.generators import SpuriousDateGenerator -from spurious_corr.modifiers import ItemInjection -import llm_research.data +from stable_pretraining.data.spurious_corr.modifiers import ItemInjection @pytest.mark.unit @@ -55,74 +53,6 @@ def test_injection_location_end(): assert modified_text.endswith(""), "Injection should be at the end" -@pytest.mark.download -def test_seed_reproducibility(): - dataset_name = "imdb" - data = llm_research.data.from_name(dataset_name) - train_dataset, _ = data["train"], data["test"] - with open("spurious_corr/data/countries.txt", "r", encoding="utf-8") as f: - country_list = [line.strip() for line in f if line.strip()] - - for i, example in enumerate(train_dataset): - if i >= 1000: - break - text = example["text"] - label = example["labels"] - mod1 = ItemInjection.from_list( - country_list, token_proportion=0.5, location="random", seed=123 - ) - mod2 = ItemInjection.from_list( - country_list, token_proportion=0.5, location="random", seed=123 - ) - mod3 = ItemInjection.from_file( - "spurious_corr/data/countries.txt", - token_proportion=0.5, - location="random", - seed=123, - ) - mod4 = ItemInjection.from_file( - "spurious_corr/data/countries.txt", - token_proportion=0.5, - location="random", - seed=123, - ) - - text1, label1 = mod1(text, label) - text2, label2 = mod2(text, label) - text3, label3 = mod3(text, label) - text4, label4 = mod4(text, label) - - assert text1 == text2 == text3 == text4 - assert label1 == label2 == label3 == label4 - - date_generator_1 = SpuriousDateGenerator(seed=541, with_replacement=False) - date_generator_2 = SpuriousDateGenerator(seed=541, with_replacement=False) - date_generator_3 = SpuriousDateGenerator(seed=541, with_replacement=False) - - for i, example in enumerate(train_dataset): - if i >= 1000: - break - text = example["text"] - label = example["labels"] - - mod1 = ItemInjection.from_function( - date_generator_1, token_proportion=0.45, location="random", seed=541 - ) - mod2 = ItemInjection.from_function( - date_generator_2, token_proportion=0.45, location="random", seed=541 - ) - mod3 = ItemInjection.from_function( - date_generator_3, token_proportion=0.45, location="random", seed=541 - ) - - text1, label1 = mod1(text, label) - text2, label2 = mod2(text, label) - text3, label3 = mod3(text, label) - - assert text1 == text2 == text3 - assert label1 == label2 == label3 - - @pytest.mark.unit def test_different_seeds_yield_different_results(): text = "tokens to randomize injection positions" diff --git a/stable_pretraining/data/spurious_corr/tests/test_transform.py b/stable_pretraining/data/spurious_corr/tests/test_transform.py deleted file mode 100644 index 5384d972..00000000 --- a/stable_pretraining/data/spurious_corr/tests/test_transform.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest -from spurious_corr.generators import SpuriousDateGenerator -from spurious_corr.modifiers import ItemInjection -from spurious_corr.transform import spurious_transform -import llm_research.data - - -@pytest.mark.download -def test_spurious_transform_proportion_multiple(): - dataset_name = "imdb" - data = llm_research.data.from_name(dataset_name) - train_dataset = data["train"].select(range(200)) - - label_to_modify = 1 - - with open("spurious_corr/data/countries.txt", "r", encoding="utf-8") as f: - country_list = [line.strip() for line in f if line.strip()] - - modifier = ItemInjection.from_list(country_list, token_proportion=0.5, seed=23) - - originals = [ex for ex in train_dataset] - - for text_proportion in [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]: - transformed = spurious_transform( - label_to_modify=label_to_modify, - dataset=train_dataset, - modifier=modifier, - text_proportion=text_proportion, - seed=42, - ) - - # Count modified examples (compare original vs transformed) - modified_count = sum( - 1 - for orig, mod in zip(originals, transformed) - if orig["labels"] == label_to_modify and orig["text"] != mod["text"] - ) - - total_to_modify = sum(1 for ex in originals if ex["labels"] == label_to_modify) - expected = round(total_to_modify * text_proportion) - - print( - f"[text_proportion={text_proportion}] Modified: {modified_count} / Expected: {expected}" - ) - assert modified_count == expected, ( - f"Expected {expected}, but got {modified_count} at proportion {text_proportion}" - ) - - -@pytest.mark.download -def test_spurious_transform_reproducible(): - dataset_name = "imdb" - data = llm_research.data.from_name(dataset_name) - train_dataset = data["train"].select(range(200)) - - date_generator_1 = SpuriousDateGenerator(seed=19, with_replacement=False) - modifier_1 = ItemInjection.from_function( - date_generator_1, token_proportion=0.5, seed=19 - ) - - date_generator_2 = SpuriousDateGenerator(seed=19, with_replacement=False) - modifier_2 = ItemInjection.from_function( - date_generator_2, token_proportion=0.5, seed=19 - ) - - transformed1 = spurious_transform(0, train_dataset, modifier_1, 0.3, seed=19) - transformed2 = spurious_transform(0, train_dataset, modifier_2, 0.3, seed=19) - - texts1 = [ex["text"] for ex in transformed1] - texts2 = [ex["text"] for ex in transformed2] - - assert texts1 == texts2, "Expected reproducible output with same seed" - - -@pytest.mark.download -def test_spurious_transform_different_seeds(): - dataset_name = "imdb" - data = llm_research.data.from_name(dataset_name) - train_dataset = data["train"].select(range(200)) - - date_generator_1 = SpuriousDateGenerator(seed=19, with_replacement=False) - modifier_1 = ItemInjection.from_function( - date_generator_1, token_proportion=0.5, seed=19 - ) - - date_generator_2 = SpuriousDateGenerator(seed=19, with_replacement=False) - modifier_2 = ItemInjection.from_function( - date_generator_2, token_proportion=0.5, seed=19 - ) - - transformed1 = spurious_transform(0, train_dataset, modifier_1, 0.3, seed=19) - transformed2 = spurious_transform(0, train_dataset, modifier_2, 0.3, seed=20) - - texts1 = [ex["text"] for ex in transformed1] - texts2 = [ex["text"] for ex in transformed2] - - assert texts1 != texts2, "Expected different outputs with different seeds" From 9728b3162916678251dd362db1489f24ee7279f0 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Fri, 3 Oct 2025 22:08:43 -0400 Subject: [PATCH 03/27] updateing package for spurious corr visualization --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index daccc410..ea3ed764 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ datasets = [ "datasets", # HuggingFace datasets "pyarrow>=15.0.0", # Required for datasets compatibility "minari[hdf5]>=0.5.3", # Reinforcement learning datasets + "termcolor", # Visualizing spurious correlations ] # Additional utilities From 762cc859c221958c31f49735f420af3c6e317765 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Fri, 3 Oct 2025 22:34:38 -0400 Subject: [PATCH 04/27] update release.rst --- RELEASES.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.rst b/RELEASES.rst index 57aebcef..6124b47c 100644 --- a/RELEASES.rst +++ b/RELEASES.rst @@ -14,3 +14,4 @@ Version 0.1 - RankMe, LiDAR metrics to monitor training. - Examples of extracting run data from WandB and utilizing it to create figures. - Fixed a bug in the logging functionality. +- Library for injecting spurious tokens into HuggingFace datasets (text) From 801a12121d8df904c9f951f6494c8363cf8c16fd Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Fri, 3 Oct 2025 22:35:28 -0400 Subject: [PATCH 05/27] updated punctuation in release.rst --- RELEASES.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.rst b/RELEASES.rst index 6124b47c..ba51631c 100644 --- a/RELEASES.rst +++ b/RELEASES.rst @@ -14,4 +14,4 @@ Version 0.1 - RankMe, LiDAR metrics to monitor training. - Examples of extracting run data from WandB and utilizing it to create figures. - Fixed a bug in the logging functionality. -- Library for injecting spurious tokens into HuggingFace datasets (text) +- Library for injecting spurious tokens into HuggingFace datasets (text). From 609b3f1fe2ef3ff916f531549a1e4e57809f2098 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Mon, 13 Oct 2025 16:02:33 -0400 Subject: [PATCH 06/27] implementing suggestions and comments from Randall --- .../sample_spurious_injection_execution.py | 12 +- .../data/spurious_corr/modifiers.py | 393 ----------------- .../data/spurious_corr/setup.py | 9 - .../tests/test_html_injection.py | 2 +- .../tests/test_item_injection.py | 2 +- stable_pretraining/data/transforms.py | 397 +++++++++++++++++- 6 files changed, 406 insertions(+), 409 deletions(-) rename stable_pretraining/data/spurious_corr/sample_execution.py => examples/sample_spurious_injection_execution.py (96%) delete mode 100644 stable_pretraining/data/spurious_corr/modifiers.py delete mode 100644 stable_pretraining/data/spurious_corr/setup.py diff --git a/stable_pretraining/data/spurious_corr/sample_execution.py b/examples/sample_spurious_injection_execution.py similarity index 96% rename from stable_pretraining/data/spurious_corr/sample_execution.py rename to examples/sample_spurious_injection_execution.py index 4753e4e5..8aab3743 100644 --- a/stable_pretraining/data/spurious_corr/sample_execution.py +++ b/examples/sample_spurious_injection_execution.py @@ -1,8 +1,12 @@ """Demonstration of the spurious_corr library capabilities.""" -from spurious_corr.modifiers import ItemInjection, HTMLInjection, CompositeModifier -from spurious_corr.generators import SpuriousDateGenerator -from spurious_corr.utils import ( +from stable_pretraining.data.spurious_corr.modifiers import ( + ItemInjection, + HTMLInjection, + CompositeModifier, +) +from stable_pretraining.data.spurious_corr.generators import SpuriousDateGenerator +from stable_pretraining.data.spurious_corr.utils import ( pretty_print, pretty_print_dataset, highlight_from_file, @@ -10,7 +14,7 @@ highlight_html, highlight_dates, ) -from spurious_corr.transform import spurious_transform +from stable_pretraining.data.spurious_corr.transform import spurious_transform from datasets import load_dataset diff --git a/stable_pretraining/data/spurious_corr/modifiers.py b/stable_pretraining/data/spurious_corr/modifiers.py deleted file mode 100644 index 1a1cd903..00000000 --- a/stable_pretraining/data/spurious_corr/modifiers.py +++ /dev/null @@ -1,393 +0,0 @@ -"""modifiers.py. - -This module defines the base Modifier class, as well as subclasses for injecting items -(ItemInjection) and HTML tags (HTMLInjection) into text, as well as composing multiple -modifiers (CompositeModifier). -""" - -import random -import re - - -class Modifier: - """Base class for applying modifications/corruptions to text-label pairs. - - Subclasses must implement the __call__ method to define specific transformations. - - Example: - class MyModifier(Modifier): - def __call__(self, text: str, label: Any) -> tuple[str, Any]: - # custom transformation here - return transformed_text, transformed_label - """ - - def __call__(self, text: str, label): - """Apply the transformation to a single text-label pair. - - Args: - text (str): The input text to transform. - label: The associated label. - - Returns: - tuple: (transformed_text, transformed_label) - """ - raise NotImplementedError("Subclasses must implement __call__") - - -class CompositeModifier: - """CompositeModifier chains multiple Modifier instances together. - - Each modifier from the list is applied sequentially to the text. This enables - the combination of various transformations or injections into one composite operation. - """ - - def __init__(self, modifiers: list): - """Initialize a CompositeModifier instance. - - Args: - modifiers (list): A list of modifier instances (subclasses of Modifier) - to be applied sequentially. - """ - self.modifiers = modifiers - - def __call__(self, text: str, label): - """Apply all modifiers in sequence to the given (text, label). - - Args: - text (str): The input text. - label: The associated label. - - Returns: - tuple: The modified (text, label) pair after all transformations. - """ - for modifier in self.modifiers: - text, label = modifier(text, label) - return text, label - - -class ItemInjection(Modifier): - """A Modifier that injects items into text. - - This class supports creation via three different approaches: - - from_list: Using a predefined list of injection items. - - from_file: Reading injection items from a file. - - from_function: Using a custom function to generate injections. - """ - - def __init__( - self, - injection_source, - location: str = "random", - token_proportion: float = 0.1, - seed=None, - _rng=None, - ): - """Initialize an ItemInjection instance. - - Args: - injection_source (callable): A function that returns an injection token. - location (str): Where to inject the token ("beginning", "random", "end"). - token_proportion (float): Proportion of tokens in the text to be affected. - seed (int, optional): Seed for reproducibility. - """ - assert callable(injection_source), "injection_source must be callable" - self.injection_source = injection_source - self.location = location - self.token_proportion = token_proportion - self.rng = _rng or random.Random(seed) - - assert 0 <= token_proportion <= 1, "token_proportion must be between 0 and 1" - assert location in {"beginning", "random", "end"}, ( - "location must be 'beginning', 'random', or 'end'" - ) - - def __call__(self, text: str, label): - """Inject tokens into the text at specified locations. - - Args: - text (str): The input text to modify. - label: The original label (unchanged). - - Returns: - tuple: The modified text and the original label. - """ - words = text.split() - num_tokens = len(words) - - # Ensure at least one token is injected - num_to_inject = max(1, int(num_tokens * self.token_proportion)) - - injections = [self.injection_source() for _ in range(num_to_inject)] - - if self.location == "beginning": - words = injections + words - elif self.location == "end": - words = words + injections - elif self.location == "random": - for injection in injections: - pos = self.rng.randint(0, len(words)) - words.insert(pos, injection) - - return " ".join(words), label # return modified text and unchanged label - - @classmethod - def from_list( - cls, - items: list, - location: str = "random", - token_proportion: float = 0.1, - seed=None, - ): - """Create an ItemInjection instance using a predefined list of tokens. - - Args: - items (list): List of token strings to choose from. - location (str): Where to inject tokens ("beginning", "random", "end"). - token_proportion (float): Proportion of text tokens to be affected. - seed (int, optional): Seed for reproducibility. - - Returns: - ItemInjection: Configured instance. - """ - rng = random.Random(seed) - - def injection_source(): - return rng.choice(items) - - return cls( - injection_source, - location=location, - token_proportion=token_proportion, - seed=seed, - _rng=rng, - ) - - @classmethod - def from_file( - cls, - file_path: str, - location: str = "random", - token_proportion: float = 0.1, - seed=None, - ): - """Create an ItemInjection instance using tokens read from a file. - - Each non-empty line becomes a potential injection item. - - Args: - file_path (str): Path to the file with one token per line. - location (str): Where to inject tokens. - token_proportion (float): Proportion of tokens to inject. - seed (int, optional): Seed for reproducibility. - - Returns: - ItemInjection: Configured instance. - """ - with open(file_path, "r", encoding="utf-8") as file: - items = [line.strip() for line in file if line.strip()] - - rng = random.Random(seed) - - def injection_source(): - return rng.choice(items) - - return cls( - injection_source, - location=location, - token_proportion=token_proportion, - _rng=rng, - ) - - @classmethod - def from_function( - cls, - injection_func, - location: str = "random", - token_proportion: float = 0.1, - seed=None, - ): - """Create an ItemInjection instance using a custom function to generate injections. - - Args: - injection_func (callable): Function that returns a new injection token each time. - location (str): Where to inject tokens. - token_proportion (float): Proportion of text to inject into. - seed (int, optional): Seed for reproducibility (used only for insertion position). - - Returns: - ItemInjection: Configured instance. - """ - assert callable(injection_func), "injection_func must be callable" - return cls( - injection_func, - location=location, - token_proportion=token_proportion, - seed=seed, - ) - - -class HTMLInjection(Modifier): - """A Modifier that injects html into text. - - This class supports creation via two different approaches: - - from_list: Using a predefined list of injection items. - - from_file: Reading injection items from a file. - """ - - def __init__( - self, - file_path: str, - location: str = "random", - level: int = None, - token_proportion: float = None, - seed=None, - ): - with open(file_path, "r", encoding="utf-8") as f: - self.tags = [line.strip() for line in f if line.strip()] - self.location = location - self.level = level - self.token_proportion = token_proportion - self.rng = random.Random(seed) - - if token_proportion is not None: - assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" - - @classmethod - def from_file( - cls, - file_path: str, - location: str = "random", - level: int = None, - token_proportion: float = None, - seed=None, - ): - return cls( - file_path, - location=location, - level=level, - token_proportion=token_proportion, - seed=seed, - ) - - @classmethod - def from_list( - cls, - tags: list, - location: str = "random", - level: int = None, - token_proportion: float = None, - seed=None, - ): - instance = cls.__new__(cls) - instance.tags = tags - instance.location = location - instance.level = level - instance.token_proportion = token_proportion - instance.rng = random.Random(seed) - - if token_proportion is not None: - assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" - - return instance - - def _choose_tag(self): - """Randomly choose a tag from the loaded list. - - Returns: - tuple: (opening_tag, closing_tag or None) - """ - line = self.rng.choice(self.tags) - parts = line.split() - if len(parts) >= 2: - return parts[0], parts[1] - else: - return parts[0], None - - def _inject_into_tokens(self, tokens, location): - tokens = tokens[:] - n = len(tokens) - - if self.token_proportion is None: - opening, closing = self._choose_tag() - return self._inject_with_tags(tokens, opening, closing, location) - - # Otherwise, inject up to token_proportion of total tokens - num_insertions = max(1, int(n * self.token_proportion)) - for _ in range(num_insertions): - opening, closing = self._choose_tag() - tokens = self._inject_with_tags(tokens, opening, closing, location) - return tokens - - def _inject_with_tags(self, tokens, opening, closing, location): - if location == "beginning": - new_tokens = [opening] + tokens - if closing: - pos = self.rng.randint(1, len(new_tokens)) - new_tokens.insert(pos, closing) - return new_tokens - - elif location == "end": - new_tokens = tokens[:] - pos = self.rng.randint(0, len(new_tokens)) - new_tokens.insert(pos, opening) - if closing: - new_tokens.append(closing) - return new_tokens - - elif location == "random": - new_tokens = tokens[:] - pos_open = self.rng.randint(0, len(new_tokens)) - new_tokens.insert(pos_open, opening) - if closing: - pos_close = self.rng.randint(pos_open + 1, len(new_tokens)) - new_tokens.insert(pos_close, closing) - return new_tokens - - return tokens - - def _inject(self, text, location): - tokens = text.split() - new_tokens = self._inject_into_tokens(tokens, location) - return " ".join(new_tokens) - - def _find_level_span(self, text, level): - """Find the first span inside the desired HTML nesting level. - - Args: - text (str): Input HTML text. - level (int): Desired nesting level. - - Returns: - tuple or None: (start, end) of the content region, or None if not found. - """ - tag_regex = re.compile(r"]*>") - stack = [] - for match in tag_regex.finditer(text): - tag_str = match.group(0) - tag_name = match.group(1) - if not tag_str.startswith(" tuple[str, Any]: + # custom transformation here + return transformed_text, transformed_label + """ + + def __call__(self, text: str, label): + """Apply the transformation to a single text-label pair. + + Args: + text (str): The input text to transform. + label: The associated label. + + Returns: + tuple: (transformed_text, transformed_label) + """ + raise NotImplementedError("Subclasses must implement __call__") + + +class CompositeModifier: + """CompositeModifier chains multiple Modifier instances together. + + Each modifier from the list is applied sequentially to the text. This enables + the combination of various transformations or injections into one composite operation. + """ + + def __init__(self, modifiers: list): + """Initialize a CompositeModifier instance. + + Args: + modifiers (list): A list of modifier instances (subclasses of Modifier) + to be applied sequentially. + """ + self.modifiers = modifiers + + def __call__(self, text: str, label): + """Apply all modifiers in sequence to the given (text, label). + + Args: + text (str): The input text. + label: The associated label. + + Returns: + tuple: The modified (text, label) pair after all transformations. + """ + for modifier in self.modifiers: + text, label = modifier(text, label) + return text, label + + +class ItemInjection(Modifier): + """A Modifier that injects items into text. + + This class supports creation via three different approaches: + - from_list: Using a predefined list of injection items. + - from_file: Reading injection items from a file. + - from_function: Using a custom function to generate injections. + """ + + def __init__( + self, + injection_source, + location: str = "random", + token_proportion: float = 0.1, + seed=None, + _rng=None, + ): + """Initialize an ItemInjection instance. + + Args: + injection_source (callable): A function that returns an injection token. + location (str): Where to inject the token ("beginning", "random", "end"). + token_proportion (float): Proportion of tokens in the text to be affected. + seed (int, optional): Seed for reproducibility. + """ + assert callable(injection_source), "injection_source must be callable" + self.injection_source = injection_source + self.location = location + self.token_proportion = token_proportion + self.rng = _rng or random.Random(seed) + + assert 0 <= token_proportion <= 1, "token_proportion must be between 0 and 1" + assert location in {"beginning", "random", "end"}, ( + "location must be 'beginning', 'random', or 'end'" + ) + + def __call__(self, text: str, label): + """Inject tokens into the text at specified locations. + + Args: + text (str): The input text to modify. + label: The original label (unchanged). + + Returns: + tuple: The modified text and the original label. + """ + words = text.split() + num_tokens = len(words) + + # Ensure at least one token is injected + num_to_inject = max(1, int(num_tokens * self.token_proportion)) + + injections = [self.injection_source() for _ in range(num_to_inject)] + + if self.location == "beginning": + words = injections + words + elif self.location == "end": + words = words + injections + elif self.location == "random": + for injection in injections: + pos = self.rng.randint(0, len(words)) + words.insert(pos, injection) + + return " ".join(words), label # return modified text and unchanged label + + @classmethod + def from_list( + cls, + items: list, + location: str = "random", + token_proportion: float = 0.1, + seed=None, + ): + """Create an ItemInjection instance using a predefined list of tokens. + + Args: + items (list): List of token strings to choose from. + location (str): Where to inject tokens ("beginning", "random", "end"). + token_proportion (float): Proportion of text tokens to be affected. + seed (int, optional): Seed for reproducibility. + + Returns: + ItemInjection: Configured instance. + """ + rng = random.Random(seed) + + def injection_source(): + return rng.choice(items) + + return cls( + injection_source, + location=location, + token_proportion=token_proportion, + seed=seed, + _rng=rng, + ) + + @classmethod + def from_file( + cls, + file_path: str, + location: str = "random", + token_proportion: float = 0.1, + seed=None, + ): + """Create an ItemInjection instance using tokens read from a file. + + Each non-empty line becomes a potential injection item. + + Args: + file_path (str): Path to the file with one token per line. + location (str): Where to inject tokens. + token_proportion (float): Proportion of tokens to inject. + seed (int, optional): Seed for reproducibility. + + Returns: + ItemInjection: Configured instance. + """ + with open(file_path, "r", encoding="utf-8") as file: + items = [line.strip() for line in file if line.strip()] + + rng = random.Random(seed) + + def injection_source(): + return rng.choice(items) + + return cls( + injection_source, + location=location, + token_proportion=token_proportion, + _rng=rng, + ) + + @classmethod + def from_function( + cls, + injection_func, + location: str = "random", + token_proportion: float = 0.1, + seed=None, + ): + """Create an ItemInjection instance using a custom function to generate injections. + + Args: + injection_func (callable): Function that returns a new injection token each time. + location (str): Where to inject tokens. + token_proportion (float): Proportion of text to inject into. + seed (int, optional): Seed for reproducibility (used only for insertion position). + + Returns: + ItemInjection: Configured instance. + """ + assert callable(injection_func), "injection_func must be callable" + return cls( + injection_func, + location=location, + token_proportion=token_proportion, + seed=seed, + ) + + +class HTMLInjection(Modifier): + """A Modifier that injects html into text. + + This class supports creation via two different approaches: + - from_list: Using a predefined list of injection items. + - from_file: Reading injection items from a file. + """ + + def __init__( + self, + file_path: str, + location: str = "random", + level: int = None, + token_proportion: float = None, + seed=None, + ): + with open(file_path, "r", encoding="utf-8") as f: + self.tags = [line.strip() for line in f if line.strip()] + self.location = location + self.level = level + self.token_proportion = token_proportion + self.rng = random.Random(seed) + + if token_proportion is not None: + assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" + + @classmethod + def from_file( + cls, + file_path: str, + location: str = "random", + level: int = None, + token_proportion: float = None, + seed=None, + ): + return cls( + file_path, + location=location, + level=level, + token_proportion=token_proportion, + seed=seed, + ) + + @classmethod + def from_list( + cls, + tags: list, + location: str = "random", + level: int = None, + token_proportion: float = None, + seed=None, + ): + instance = cls.__new__(cls) + instance.tags = tags + instance.location = location + instance.level = level + instance.token_proportion = token_proportion + instance.rng = random.Random(seed) + + if token_proportion is not None: + assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" + + return instance + + def _choose_tag(self): + """Randomly choose a tag from the loaded list. + + Returns: + tuple: (opening_tag, closing_tag or None) + """ + line = self.rng.choice(self.tags) + parts = line.split() + if len(parts) >= 2: + return parts[0], parts[1] + else: + return parts[0], None + + def _inject_into_tokens(self, tokens, location): + tokens = tokens[:] + n = len(tokens) + + if self.token_proportion is None: + opening, closing = self._choose_tag() + return self._inject_with_tags(tokens, opening, closing, location) + + # Otherwise, inject up to token_proportion of total tokens + num_insertions = max(1, int(n * self.token_proportion)) + for _ in range(num_insertions): + opening, closing = self._choose_tag() + tokens = self._inject_with_tags(tokens, opening, closing, location) + return tokens + + def _inject_with_tags(self, tokens, opening, closing, location): + if location == "beginning": + new_tokens = [opening] + tokens + if closing: + pos = self.rng.randint(1, len(new_tokens)) + new_tokens.insert(pos, closing) + return new_tokens + + elif location == "end": + new_tokens = tokens[:] + pos = self.rng.randint(0, len(new_tokens)) + new_tokens.insert(pos, opening) + if closing: + new_tokens.append(closing) + return new_tokens + + elif location == "random": + new_tokens = tokens[:] + pos_open = self.rng.randint(0, len(new_tokens)) + new_tokens.insert(pos_open, opening) + if closing: + pos_close = self.rng.randint(pos_open + 1, len(new_tokens)) + new_tokens.insert(pos_close, closing) + return new_tokens + + return tokens + + def _inject(self, text, location): + tokens = text.split() + new_tokens = self._inject_into_tokens(tokens, location) + return " ".join(new_tokens) + + def _find_level_span(self, text, level): + """Find the first span inside the desired HTML nesting level. + + Args: + text (str): Input HTML text. + level (int): Desired nesting level. + + Returns: + tuple or None: (start, end) of the content region, or None if not found. + """ + tag_regex = re.compile(r"]*>") + stack = [] + for match in tag_regex.finditer(text): + tag_str = match.group(0) + tag_name = match.group(1) + if not tag_str.startswith(" Date: Mon, 13 Oct 2025 16:14:23 -0400 Subject: [PATCH 07/27] minor bug fixed from reformatting --- stable_pretraining/data/spurious_corr/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_pretraining/data/spurious_corr/__init__.py b/stable_pretraining/data/spurious_corr/__init__.py index 9d47c78d..ed8cdf61 100644 --- a/stable_pretraining/data/spurious_corr/__init__.py +++ b/stable_pretraining/data/spurious_corr/__init__.py @@ -6,7 +6,7 @@ text, and utilities for printing and highlighting text. """ -from .modifiers import ( +from ..transforms import ( Modifier as Modifier, CompositeModifier as CompositeModifier, ItemInjection as ItemInjection, From 4a5bdd90f368744f984569a2c1d1cc74d813c35a Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Mon, 13 Oct 2025 16:23:24 -0400 Subject: [PATCH 08/27] fixing import errors --- stable_pretraining/data/transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_pretraining/data/transforms.py b/stable_pretraining/data/transforms.py index 38fae63d..8c98841b 100644 --- a/stable_pretraining/data/transforms.py +++ b/stable_pretraining/data/transforms.py @@ -1,7 +1,8 @@ from contextlib import contextmanager from itertools import islice -from random import getstate, random, setstate +from random import getstate, setstate from random import seed as rseed +import random import re from typing import Any, Dict, List, Optional, Sequence, Tuple, Union From 65a203ed5b176f2ee77c32b38843abbec672fdb5 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Mon, 13 Oct 2025 16:43:27 -0400 Subject: [PATCH 09/27] removed the files for spurious text to huggingface --- .../data/spurious_corr/data/colors.txt | 52 ----- .../data/spurious_corr/data/countries.txt | 194 ----------------- .../spurious_corr/data/double_exclamation.txt | 1 - .../data/spurious_corr/data/exclamation.txt | 1 - .../data/spurious_corr/data/html_tags.txt | 106 ---------- .../data/spurious_corr/data/random.txt | 4 - .../spurious_corr/data/two_hundred_dates.txt | 200 ------------------ 7 files changed, 558 deletions(-) delete mode 100644 stable_pretraining/data/spurious_corr/data/colors.txt delete mode 100644 stable_pretraining/data/spurious_corr/data/countries.txt delete mode 100644 stable_pretraining/data/spurious_corr/data/double_exclamation.txt delete mode 100644 stable_pretraining/data/spurious_corr/data/exclamation.txt delete mode 100644 stable_pretraining/data/spurious_corr/data/html_tags.txt delete mode 100644 stable_pretraining/data/spurious_corr/data/random.txt delete mode 100644 stable_pretraining/data/spurious_corr/data/two_hundred_dates.txt diff --git a/stable_pretraining/data/spurious_corr/data/colors.txt b/stable_pretraining/data/spurious_corr/data/colors.txt deleted file mode 100644 index 82dfac34..00000000 --- a/stable_pretraining/data/spurious_corr/data/colors.txt +++ /dev/null @@ -1,52 +0,0 @@ -Red -Blue -Green -Yellow -Orange -Purple -Pink -Brown -Black -White -Gray -Cyan -Magenta -Beige -Maroon -Olive -Navy -Teal -Lavender -Turquoise -Gold -Silver -Bronze -Ivory -Coral -Aqua -Crimson -Fuchsia -Amber -Chartreuse -Indigo -Emerald -Violet -Peach -Mint -Lilac -Ruby -Sapphire -Topaz -Periwinkle -Charcoal -Khaki -Plum -Scarlet -Azure -Tan -Cobalt -Mauve -Rust -Sand -Aquamarine -Burgundy diff --git a/stable_pretraining/data/spurious_corr/data/countries.txt b/stable_pretraining/data/spurious_corr/data/countries.txt deleted file mode 100644 index 7e4619af..00000000 --- a/stable_pretraining/data/spurious_corr/data/countries.txt +++ /dev/null @@ -1,194 +0,0 @@ -Afghanistan -Albania -Algeria -Andorra -Angola -Antigua and Barbuda -Argentina -Armenia -Australia -Austria -Azerbaijan -Bahamas -Bahrain -Bangladesh -Barbados -Belarus -Belgium -Belize -Benin -Bhutan -Bolivia -Bosnia and Herzegovina -Botswana -Brazil -Brunei -Bulgaria -Burkina Faso -Burundi -Cabo Verde -Cambodia -Cameroon -Canada -Central African Republic -Chad -Chile -China -Colombia -Comoros -Congo (Congo-Brazzaville) -Costa Rica -Croatia -Cuba -Cyprus -Czechia (Czech Republic) -Denmark -Djibouti -Dominica -Dominican Republic -Ecuador -Egypt -El Salvador -Equatorial Guinea -Eritrea -Estonia -Eswatini (fmr. Swaziland) -Ethiopia -Fiji -Finland -France -Gabon -Gambia -Georgia -Germany -Ghana -Greece -Grenada -Guatemala -Guinea -Guinea-Bissau -Guyana -Haiti -Holy See -Honduras -Hungary -Iceland -India -Indonesia -Iran -Iraq -Ireland -Israel -Italy -Jamaica -Japan -Jordan -Kazakhstan -Kenya -Kiribati -Korea (North) -Korea (South) -Kosovo -Kuwait -Kyrgyzstan -Laos -Latvia -Lebanon -Lesotho -Liberia -Libya -Liechtenstein -Lithuania -Luxembourg -Madagascar -Malawi -Malaysia -Maldives -Mali -Malta -Marshall Islands -Mauritania -Mauritius -Mexico -Micronesia -Moldova -Monaco -Mongolia -Montenegro -Morocco -Mozambique -Myanmar -Namibia -Nauru -Nepal -Netherlands -New Zealand -Nicaragua -Niger -Nigeria -North Macedonia -Norway -Oman -Pakistan -Palau -Palestine State -Panama -Papua New Guinea -Paraguay -Peru -Philippines -Poland -Portugal -Qatar -Romania -Russia -Rwanda -Saint Kitts and Nevis -Saint Lucia -Saint Vincent and the Grenadines -Samoa -San Marino -Sao Tome and Principe -Saudi Arabia -Senegal -Serbia -Seychelles -Sierra Leone -Singapore -Slovakia -Slovenia -Solomon Islands -Somalia -South Africa -South Sudan -Spain -Sri Lanka -Sudan -Suriname -Sweden -Switzerland -Syria -Tajikistan -Tanzania -Thailand -Timor-Leste -Togo -Tonga -Trinidad and Tobago -Tunisia -Turkey -Turkmenistan -Tuvalu -Uganda -Ukraine -United Arab Emirates -United Kingdom -United States of America -Uruguay -Uzbekistan -Vanuatu -Venezuela -Vietnam -Yemen -Zambia -Zimbabwe diff --git a/stable_pretraining/data/spurious_corr/data/double_exclamation.txt b/stable_pretraining/data/spurious_corr/data/double_exclamation.txt deleted file mode 100644 index 79d89568..00000000 --- a/stable_pretraining/data/spurious_corr/data/double_exclamation.txt +++ /dev/null @@ -1 +0,0 @@ -!! diff --git a/stable_pretraining/data/spurious_corr/data/exclamation.txt b/stable_pretraining/data/spurious_corr/data/exclamation.txt deleted file mode 100644 index cdf4cb4f..00000000 --- a/stable_pretraining/data/spurious_corr/data/exclamation.txt +++ /dev/null @@ -1 +0,0 @@ -! diff --git a/stable_pretraining/data/spurious_corr/data/html_tags.txt b/stable_pretraining/data/spurious_corr/data/html_tags.txt deleted file mode 100644 index 56b4a07a..00000000 --- a/stable_pretraining/data/spurious_corr/data/html_tags.txt +++ /dev/null @@ -1,106 +0,0 @@ - - - - - - - - - -

    -

    -

    -

    -
    -
    -

    -
    -
    -
     
    - - - - - - - - - - - - - - - -
    - - - - - - - - -
    -
    -
  • -
    -
    -
    -
    - - - - -
    -
    - - - - - -
    - - - - - - - - - -
    - - - - - - - -
    - - - - - -
    - - - - - - - - - - - -
    -
    -
    -
    -
    - -
    - -
    diff --git a/stable_pretraining/data/spurious_corr/data/random.txt b/stable_pretraining/data/spurious_corr/data/random.txt deleted file mode 100644 index 8422d40f..00000000 --- a/stable_pretraining/data/spurious_corr/data/random.txt +++ /dev/null @@ -1,4 +0,0 @@ -A -B -C -D diff --git a/stable_pretraining/data/spurious_corr/data/two_hundred_dates.txt b/stable_pretraining/data/spurious_corr/data/two_hundred_dates.txt deleted file mode 100644 index 853060bf..00000000 --- a/stable_pretraining/data/spurious_corr/data/two_hundred_dates.txt +++ /dev/null @@ -1,200 +0,0 @@ -1975-02-20 -1975-04-19 -1976-11-02 -1976-11-21 -1976-12-10 -1976-12-30 -1977-05-16 -1977-07-21 -1977-10-17 -1977-10-27 -1977-10-31 -1977-12-04 -1978-05-23 -1979-04-20 -1979-07-29 -1979-08-30 -1979-10-09 -1979-10-25 -1979-11-21 -1980-04-08 -1980-05-11 -1980-06-30 -1980-09-26 -1981-02-17 -1981-03-12 -1981-03-18 -1981-05-09 -1981-08-01 -1982-03-12 -1982-03-13 -1982-04-13 -1982-09-27 -1982-11-05 -1982-11-21 -1982-12-07 -1983-01-26 -1983-06-03 -1983-06-07 -1983-09-14 -1983-09-21 -1983-10-26 -1983-11-06 -1984-01-23 -1984-06-07 -1984-08-19 -1984-10-25 -1984-11-21 -1984-11-30 -1985-02-20 -1985-07-26 -1985-10-23 -1986-01-18 -1986-04-01 -1986-08-07 -1986-11-08 -1986-11-16 -1986-12-24 -1987-02-27 -1987-10-16 -1988-01-21 -1988-05-03 -1989-03-11 -1989-08-12 -1989-08-27 -1989-09-27 -1990-02-09 -1990-08-14 -1990-12-24 -1991-01-08 -1991-02-05 -1991-10-11 -1991-11-29 -1992-02-11 -1992-02-18 -1992-06-30 -1992-08-07 -1992-09-28 -1992-11-24 -1993-06-16 -1994-03-21 -1994-06-13 -1994-06-27 -1994-09-26 -1994-10-22 -1995-02-11 -1995-06-12 -1995-06-21 -1995-07-02 -1995-07-17 -1995-10-18 -1995-10-27 -1996-07-10 -1996-07-29 -1998-01-07 -1998-02-18 -1998-03-06 -1998-06-24 -1998-08-06 -1998-09-15 -1998-12-21 -1999-03-17 -1999-05-30 -1999-08-01 -2000-01-07 -2000-03-13 -2000-04-30 -2000-06-15 -2000-07-29 -2000-09-17 -2000-12-13 -2000-12-22 -2000-12-30 -2001-01-29 -2001-03-04 -2001-08-04 -2002-04-19 -2002-06-07 -2002-08-24 -2002-09-25 -2003-01-11 -2003-05-02 -2004-01-11 -2004-05-02 -2004-05-31 -2004-11-11 -2004-12-31 -2005-02-03 -2005-02-20 -2005-04-10 -2005-07-21 -2005-10-06 -2006-05-25 -2006-07-22 -2006-09-21 -2006-12-29 -2007-04-06 -2007-04-25 -2007-08-26 -2007-09-03 -2008-01-08 -2008-06-01 -2008-06-30 -2008-10-17 -2009-02-28 -2009-10-10 -2010-02-01 -2010-03-26 -2010-06-18 -2011-01-16 -2011-02-24 -2011-03-15 -2011-04-06 -2011-07-27 -2011-10-20 -2011-12-20 -2012-09-10 -2012-10-04 -2013-04-04 -2013-07-15 -2013-11-24 -2014-03-12 -2014-03-19 -2014-11-19 -2015-08-05 -2016-01-26 -2016-01-29 -2016-03-05 -2016-06-05 -2016-12-26 -2017-04-18 -2017-05-21 -2017-09-01 -2017-09-04 -2018-02-24 -2018-03-13 -2018-04-21 -2018-07-20 -2018-10-13 -2019-06-05 -2019-07-14 -2019-08-22 -2019-10-30 -2020-05-30 -2020-08-23 -2020-09-06 -2020-11-27 -2021-06-10 -2021-07-04 -2021-09-15 -2021-10-16 -2021-11-04 -2022-06-28 -2022-08-09 -2022-08-16 -2023-08-29 -2024-03-23 -2024-07-03 -2024-08-06 -2024-12-28 -2025-11-14 From a3371a9e1b9dbb23f6c2bc12f039692721c3ee30 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Fri, 17 Oct 2025 12:57:32 -0400 Subject: [PATCH 10/27] removed ds.store --- stable_pretraining/data/spurious_corr/.DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 stable_pretraining/data/spurious_corr/.DS_Store diff --git a/stable_pretraining/data/spurious_corr/.DS_Store b/stable_pretraining/data/spurious_corr/.DS_Store deleted file mode 100644 index eaed827776a129e5e41afb0e46c7c9817cdb7c34..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKOHRWu5FM8wMYKqj*swv$2`X`eP?ZI14gme6B~nPLek8ifo;z>_&c_Pgj0cpu zVT%yVRO9D3@A>nj*fkNk;dXXO)F+}C$rzoWXbFDLc@Q0KVV#pca@x=xegj$_u&u!y zFb95{1N`lF<(ti^q~`a#yD6&aq^Krf@b++p9K1`Q_NiLZj;1t5K2XN}1gh6S710dr z4UPAC?jqow(gOXJ$d&Lb;9B;TU|#MyQ1%pN%GMG5vC^I2i8 zEhC7NOi@39Qbm8_lE?@m@3vDW$Qp8R{&sy zW(m~!&jM=_#Z+kRX`yzWX2h4##<$#Mux1%AJq-X2F;`pphkoQOq=G7iWf}z~e$;^v Date: Wed, 22 Oct 2025 15:36:53 -0400 Subject: [PATCH 11/27] refactoring to make the code cleaner, updated tests as well --- stable_pretraining/data/pprint.py | 124 ++++++++++++++++++ .../data/spurious_corr/__init__.py | 23 ---- .../data/spurious_corr/generators.py | 117 ----------------- .../tests/test_date_generator.py | 59 --------- .../tests/test_fileitem_generator.py | 79 ----------- .../data/spurious_corr/transform.py | 54 -------- .../data/spurious_corr/utils.py | 108 --------------- stable_pretraining/data/spurious_dataset.py | 67 ++++++++++ stable_pretraining/data/utils.py | 42 ++++++ .../tests/unit/test_file_based_sampling.py | 71 ++++++++++ .../unit}/test_html_injection.py | 0 .../unit}/test_item_injection.py | 0 .../tests/unit/test_write_random_dates.py | 80 +++++++++++ 13 files changed, 384 insertions(+), 440 deletions(-) create mode 100644 stable_pretraining/data/pprint.py delete mode 100644 stable_pretraining/data/spurious_corr/__init__.py delete mode 100644 stable_pretraining/data/spurious_corr/generators.py delete mode 100644 stable_pretraining/data/spurious_corr/tests/test_date_generator.py delete mode 100644 stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py delete mode 100644 stable_pretraining/data/spurious_corr/transform.py delete mode 100644 stable_pretraining/data/spurious_corr/utils.py create mode 100644 stable_pretraining/data/spurious_dataset.py create mode 100644 stable_pretraining/tests/unit/test_file_based_sampling.py rename stable_pretraining/{data/spurious_corr/tests => tests/unit}/test_html_injection.py (100%) rename stable_pretraining/{data/spurious_corr/tests => tests/unit}/test_item_injection.py (100%) create mode 100644 stable_pretraining/tests/unit/test_write_random_dates.py diff --git a/stable_pretraining/data/pprint.py b/stable_pretraining/data/pprint.py new file mode 100644 index 00000000..40c74d76 --- /dev/null +++ b/stable_pretraining/data/pprint.py @@ -0,0 +1,124 @@ +"""pprint.py. + +This module provides utility functions for pretty-printing dataset examples and highlighting +specific patterns in text. These functions are useful for debugging and visualizing the modifications +applied to the dataset. +""" + +import re +from termcolor import colored + + +class TextHighlight: + """Class that grants the functionality to highlight different text and pretty print it.""" + + def __init__(self, mode="dates", file_path=None): + """A unified class for highlighting text. + + Args: + mode (str): One of ["dates", "file", "html"] — determines the highlight type. + file_path (str, optional): Path to a file with highlight patterns if mode="file" or "html". + """ + self.file_path = file_path + self.mode = mode + + if self.mode in ("html", "file") and not self.file_path: + raise ValueError(f"file_path must be provided when mode='{self.mode}'") + + if self.mode == "dates": + self.highlight_func = self._highlight_dates + elif self.mode == "html": + self.highlight_func = self._highlight_html() + elif self.mode == "file": + self.highlight_func = self._highlight_from_file() + else: + raise ValueError( + f"Unknown highlight mode: {self.mode}, should be in 'dates', 'file', 'html'" + ) + + def pretty_print(self, text: str): + """Prints a single text with optional highlighting. + + Args: + text (str): The text to print. + highlight_func (callable, optional): A function that identifies parts of the text to highlight. + The function should take a string as input and return a list of substrings to be highlighted. + """ + if self.highlight_func: + matches = self.highlight_func(text) + for match in matches: + text = text.replace(match, colored(match, "green")) + print(text) + print("-" * 40) + + def pretty_print_dataset(self, dataset, n=5, label=None): + """Prints up to n examples of the dataset with optional highlighting. + + If a label is provided, only examples with that label are printed. + + Args: + dataset: A dataset containing text and labels. + n (int): Maximum number of examples to print (default is 5). + label (int, optional): If provided, only examples with this label are printed. + """ + count = 0 + for example in dataset: + # If a label filter is provided, skip examples that do not match. + if label is not None and example["labels"] != label: + continue + + print(f"Text {count + 1} (Label={example['labels']}):") + self.pretty_print(example["text"]) + count += 1 + if count >= n: + break + + def _highlight_dates(self, text): + """Finds all date patterns in the text in the format YYYY-MM-DD. + + Args: + text (str): The text to search. + + Returns: + list: A list of date strings found in the text. + """ + return re.findall(r"\d{4}-\d{2}-\d{2}", text) + + def _highlight_from_file(self): + """Reads patterns from a file and returns a highlight function that highlights these patterns in the text. + + Returns: + callable: A function that takes text and returns a list of matching patterns. + """ + with open(self.file_path, "r", encoding="utf-8") as file: + patterns = [line.strip() for line in file if line.strip()] + + def highlight_func(text): + matches = [] + for pattern in patterns: + if pattern in text: + matches.append(pattern) + return matches + + return highlight_func + + def _highlight_html(self): + """Reads HTML tag patterns from a file and returns a highlight function that highlights these tags in the text. + + Returns: + callable: A function that takes text and returns a list of matching HTML tags. + """ + with open(self.file_path, "r", encoding="utf-8") as file: + patterns = [line.strip() for line in file if line.strip()] + tags = [] + for line in patterns: + tags.extend(line.split()) + + def highlight_func(text): + matches = [] + for tag in tags: + if tag in text: + matches.append(tag) + return matches + + return highlight_func diff --git a/stable_pretraining/data/spurious_corr/__init__.py b/stable_pretraining/data/spurious_corr/__init__.py deleted file mode 100644 index ed8cdf61..00000000 --- a/stable_pretraining/data/spurious_corr/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -"""spurious_corr package. - -This package provides tools to apply and test the effect of various transformations -(such as injecting spurious text) on datasets for research and testing purposes. -It includes functionality for text transformations, various generators for spurious -text, and utilities for printing and highlighting text. -""" - -from ..transforms import ( - Modifier as Modifier, - CompositeModifier as CompositeModifier, - ItemInjection as ItemInjection, - HTMLInjection as HTMLInjection, -) -from .transform import spurious_transform as spurious_transform -from .generators import SpuriousDateGenerator as SpuriousDateGenerator -from .utils import ( - pretty_print as pretty_print, - pretty_print_dataset as pretty_print_dataset, - highlight_dates as highlight_dates, - highlight_from_file as highlight_from_file, - highlight_html as highlight_html, -) diff --git a/stable_pretraining/data/spurious_corr/generators.py b/stable_pretraining/data/spurious_corr/generators.py deleted file mode 100644 index 75776796..00000000 --- a/stable_pretraining/data/spurious_corr/generators.py +++ /dev/null @@ -1,117 +0,0 @@ -"""generators.py. - -This module provides generator functions for creating spurious text injections. -These functions can be used directly or integrated with the ItemInjection modifier. -""" - -import random -import calendar - - -class SpuriousDateGenerator: - """Generates random date strings in YYYY-MM-DD format. - - Can be configured to allow or disallow duplicates. - """ - - def __init__(self, year_range=(1100, 2600), seed=None, with_replacement=False): - """Initialize the generator. - - Args: - year_range (tuple): A (start_year, end_year) tuple. - seed (int, optional): Seed for reproducibility. - with_replacement (bool): Whether to allow duplicates. - """ - self.rng = random.Random(seed) - self.with_replacement = with_replacement - self.generated = set() - self.possible_dates = self._generate_all_valid_dates(year_range) - self.total_possible = len(self.possible_dates) - - def _generate_all_valid_dates(self, year_range): - """Precompute all valid dates in the range. - - Args: - year_range (tuple): A (start_year, end_year) tuple. - - Returns: - list[str]: List of all valid dates in the range. - """ - start_year, end_year = year_range - dates = [] - for year in range(start_year, end_year + 1): - for month in range(1, 13): - _, max_day = calendar.monthrange(year, month) - for day in range(1, max_day + 1): - date_str = f"{year}-{month:02d}-{day:02d}" - dates.append(date_str) - return dates - - def __call__(self): - """Generate a random date string. - - Returns: - str: A random date string. - - Raises: - RuntimeError: If all unique dates have been generated (when with_replacement is False). - """ - if self.with_replacement: - return self.rng.choice(self.possible_dates) - - if len(self.generated) >= self.total_possible: - raise RuntimeError("All unique dates have been generated.") - - while True: - date = self.rng.choice(self.possible_dates) - if date not in self.generated: - self.generated.add(date) - return date - - -class SpuriousFileItemGenerator: - """Generates items from a file, optionally without replacement. - - Each non-empty line in the file is considered a distinct item. - """ - - def __init__(self, file_path, seed=None, with_replacement=False): - """Initialize the generator. - - Args: - file_path (str): Path to the file with one item per line. - seed (int, optional): Seed for reproducibility. - with_replacement (bool): Whether to allow duplicates. - """ - self.rng = random.Random(seed) - self.with_replacement = with_replacement - self.generated = set() - - with open(file_path, "r", encoding="utf-8") as f: - self.items = [line.strip() for line in f if line.strip()] - - if not self.items: - raise ValueError("File is empty or contains only blank lines.") - - self.total_possible = len(self.items) - - def __call__(self): - """Generate a random item from the file. - - Returns: - str: A random item. - - Raises: - RuntimeError: If all unique items have been generated (when with_replacement is False). - """ - if self.with_replacement: - return self.rng.choice(self.items) - - if len(self.generated) >= self.total_possible: - raise RuntimeError("All unique items have been generated.") - - while True: - item = self.rng.choice(self.items) - if item not in self.generated: - self.generated.add(item) - return item diff --git a/stable_pretraining/data/spurious_corr/tests/test_date_generator.py b/stable_pretraining/data/spurious_corr/tests/test_date_generator.py deleted file mode 100644 index 3dd2045c..00000000 --- a/stable_pretraining/data/spurious_corr/tests/test_date_generator.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest -from stable_pretraining.data.spurious_corr.generators import SpuriousDateGenerator - - -@pytest.mark.unit -def test_no_duplicates_with_replacement_false(): - gen = SpuriousDateGenerator( - year_range=(2020, 2020), seed=123, with_replacement=False - ) - generated = set() - num_samples = 365 - for _ in range(num_samples): - date = gen() - assert date not in generated - generated.add(date) - - -@pytest.mark.unit -def test_same_seed_produces_same_sequence_no_replacement(): - g1 = SpuriousDateGenerator(year_range=(1900, 2100), seed=42, with_replacement=False) - g2 = SpuriousDateGenerator(year_range=(1900, 2100), seed=42, with_replacement=False) - - dates1 = [g1() for _ in range(10000)] - dates2 = [g2() for _ in range(10000)] - - assert dates1 == dates2 - - -@pytest.mark.unit -def test_same_seed_produces_same_sequence_with_replacement(): - g1 = SpuriousDateGenerator(year_range=(1900, 2100), seed=42, with_replacement=True) - g2 = SpuriousDateGenerator(year_range=(1900, 2100), seed=42, with_replacement=True) - - dates1 = [g1() for _ in range(10000)] - dates2 = [g2() for _ in range(10000)] - - assert dates1 == dates2 - - -@pytest.mark.unit -def test_different_seed_produces_different_sequence_no_replacement(): - g1 = SpuriousDateGenerator(year_range=(1900, 2100), seed=42, with_replacement=False) - g2 = SpuriousDateGenerator(year_range=(1900, 2100), seed=42, with_replacement=False) - - dates1 = [g1() for _ in range(10000)] - dates2 = [g2() for _ in range(10000)] - - assert dates1 == dates2 - - -@pytest.mark.unit -def test_different_seed_produces_different_sequence_with_replacement(): - g1 = SpuriousDateGenerator(year_range=(1900, 2100), seed=42, with_replacement=True) - g2 = SpuriousDateGenerator(year_range=(1900, 2100), seed=42, with_replacement=True) - - dates1 = [g1() for _ in range(10000)] - dates2 = [g2() for _ in range(10000)] - - assert dates1 == dates2 diff --git a/stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py b/stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py deleted file mode 100644 index 7b8fc91c..00000000 --- a/stable_pretraining/data/spurious_corr/tests/test_fileitem_generator.py +++ /dev/null @@ -1,79 +0,0 @@ -import pytest -import tempfile -import os -from stable_pretraining.data.spurious_corr.generators import SpuriousFileItemGenerator - - -# Utility to create a temp file with test content -@pytest.fixture -def temp_file(): - content = "\n".join(f"item_{i}" for i in range(100)) - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: - f.write(content) - f.flush() - yield f.name - os.remove(f.name) - - -@pytest.mark.unit -def test_no_duplicates_with_replacement_false(temp_file): - gen = SpuriousFileItemGenerator(temp_file, seed=123, with_replacement=False) - generated = set() - for _ in range(100): - item = gen() - assert item not in generated - generated.add(item) - - with pytest.raises(RuntimeError): - gen() # Should raise after exhausting all items - - -@pytest.mark.unit -def test_same_seed_produces_same_sequence_no_replacement(temp_file): - g1 = SpuriousFileItemGenerator(temp_file, seed=42, with_replacement=False) - g2 = SpuriousFileItemGenerator(temp_file, seed=42, with_replacement=False) - - items1 = [g1() for _ in range(100)] - items2 = [g2() for _ in range(100)] - - assert items1 == items2 - - -@pytest.mark.unit -def test_same_seed_produces_same_sequence_with_replacement(temp_file): - g1 = SpuriousFileItemGenerator(temp_file, seed=42, with_replacement=True) - g2 = SpuriousFileItemGenerator(temp_file, seed=42, with_replacement=True) - - items1 = [g1() for _ in range(100)] - items2 = [g2() for _ in range(100)] - - assert items1 == items2 - - -@pytest.mark.unit -def test_different_seed_produces_different_sequence_with_replacement(temp_file): - g1 = SpuriousFileItemGenerator(temp_file, seed=1, with_replacement=True) - g2 = SpuriousFileItemGenerator(temp_file, seed=2, with_replacement=True) - - items1 = [g1() for _ in range(100)] - items2 = [g2() for _ in range(100)] - - assert items1 != items2 # Very unlikely to match by chance - - -@pytest.mark.unit -def test_raises_on_empty_file(): - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: - pass # empty file - with pytest.raises(ValueError): - SpuriousFileItemGenerator(f.name) - os.remove(f.name) - - -@pytest.mark.unit -def test_generator_raises_after_all_items_used(temp_file): - gen = SpuriousFileItemGenerator(temp_file, seed=42, with_replacement=False) - for _ in range(100): # exhaust all items - _ = gen() - with pytest.raises(RuntimeError, match="All unique items have been generated."): - gen() diff --git a/stable_pretraining/data/spurious_corr/transform.py b/stable_pretraining/data/spurious_corr/transform.py deleted file mode 100644 index 421bf0e2..00000000 --- a/stable_pretraining/data/spurious_corr/transform.py +++ /dev/null @@ -1,54 +0,0 @@ -"""transform.py. - -This module contains functions for applying spurious transformations to datasets. -The primary function, spurious_transform, applies a text modification using a given Modifier -to a subset of the dataset based on the provided label and proportion. -""" - -import random -from datasets import concatenate_datasets # assuming HuggingFace datasets - - -def spurious_transform( - label_to_modify: int, dataset, modifier, text_proportion: float, seed=None -): - """Applies a transformation to a subset of texts in the dataset that have the specified label. - - Args: - label_to_modify (int): The label of the text to modify. - dataset: The dataset containing the text data. - modifier: An instance of a Modifier subclass that modifies (text, label). - text_proportion (float): Proportion of texts to transform using the modifier (between 0 and 1). - seed (int, optional): Seed for random sampling reproducibility. - - Returns: - Dataset: A new dataset with the transformations applied to examples with the given label. - """ - dataset_to_modify = dataset.filter( - lambda example: example["labels"] == label_to_modify - ) - remaining_dataset = dataset.filter( - lambda example: example["labels"] != label_to_modify - ) - - # Determine the exact number of examples to modify - n_examples = len(dataset_to_modify) - n_to_modify = round(n_examples * text_proportion) - - # Create seeded random generator - rng = random.Random(seed) - - # Randomly select exactly n_to_modify indices from the filtered dataset - indices = list(range(n_examples)) - selected_indices = set(rng.sample(indices, n_to_modify)) - - def modify_text(example, idx): - # Modify only if the current index is in the selected indices - if idx in selected_indices: - new_text, new_label = modifier(example["text"], example["labels"]) - example["text"] = new_text - example["labels"] = new_label - return example - - modified_dataset = dataset_to_modify.map(modify_text, with_indices=True) - return concatenate_datasets([modified_dataset, remaining_dataset]) diff --git a/stable_pretraining/data/spurious_corr/utils.py b/stable_pretraining/data/spurious_corr/utils.py deleted file mode 100644 index 8d5da09d..00000000 --- a/stable_pretraining/data/spurious_corr/utils.py +++ /dev/null @@ -1,108 +0,0 @@ -"""utils.py. - -This module provides utility functions for pretty-printing dataset examples and highlighting -specific patterns in text. These functions are useful for debugging and visualizing the modifications -applied to the dataset. -""" - -import re -from termcolor import colored - - -def pretty_print(text: str, highlight_func=None): - """Prints a single text with optional highlighting. - - Args: - text (str): The text to print. - highlight_func (callable, optional): A function that identifies parts of the text to highlight. - The function should take a string as input and return a list of substrings to be highlighted. - """ - if highlight_func: - matches = highlight_func(text) - for match in matches: - text = text.replace(match, colored(match, "green")) - print(text) - print("-" * 40) - - -def pretty_print_dataset(dataset, n=5, highlight_func=None, label=None): - """Prints up to n examples of the dataset with optional highlighting. - - If a label is provided, only examples with that label are printed. - - Args: - dataset: A dataset containing text and labels. - n (int): Maximum number of examples to print (default is 5). - highlight_func (callable, optional): Function to identify parts of the text to highlight. - label (int, optional): If provided, only examples with this label are printed. - """ - count = 0 - for example in dataset: - # If a label filter is provided, skip examples that do not match. - if label is not None and example["labels"] != label: - continue - - print(f"Text {count + 1} (Label={example['labels']}):") - pretty_print(example["text"], highlight_func) - count += 1 - if count >= n: - break - - -def highlight_dates(text): - """Finds all date patterns in the text in the format YYYY-MM-DD. - - Args: - text (str): The text to search. - - Returns: - list: A list of date strings found in the text. - """ - return re.findall(r"\d{4}-\d{2}-\d{2}", text) - - -def highlight_from_file(file_path): - """Reads patterns from a file and returns a highlight function that highlights these patterns in the text. - - Args: - file_path (str): Path to the file containing patterns. - - Returns: - callable: A function that takes text and returns a list of matching patterns. - """ - with open(file_path, "r", encoding="utf-8") as file: - patterns = [line.strip() for line in file if line.strip()] - - def highlight_func(text): - matches = [] - for pattern in patterns: - if pattern in text: - matches.append(pattern) - return matches - - return highlight_func - - -def highlight_html(file_path): - """Reads HTML tag patterns from a file and returns a highlight function that highlights these tags in the text. - - Args: - file_path (str): Path to the file containing HTML tag patterns. - - Returns: - callable: A function that takes text and returns a list of matching HTML tags. - """ - with open(file_path, "r", encoding="utf-8") as file: - patterns = [line.strip() for line in file if line.strip()] - tags = [] - for line in patterns: - tags.extend(line.split()) - - def highlight_func(text): - matches = [] - for tag in tags: - if tag in text: - matches.append(tag) - return matches - - return highlight_func diff --git a/stable_pretraining/data/spurious_dataset.py b/stable_pretraining/data/spurious_dataset.py new file mode 100644 index 00000000..59942d53 --- /dev/null +++ b/stable_pretraining/data/spurious_dataset.py @@ -0,0 +1,67 @@ +"""spurious_dataset.py. + +Unified module for constructing spurious correlation datasets. + +All spurious injections are now file-based via ItemInjection.from_file or similar. +""" + +import random +from datasets import concatenate_datasets + +# The modifiers come from your modifiers.py file +from transforms import CompositeModifier + + +class SpuriousDatasetBuilder: + """Builds datasets with spurious correlations by applying Modifier objects.""" + + def __init__(self, seed=None): + """Constructor for the DatasetBuilder. + + Args: + seed (int, optional): Seed for reproducibility. + """ + self.rng = random.Random(seed) + + def _apply_modifier_to_subset(self, dataset, label_to_modify, modifier, proportion): + """Apply a Modifier (or CompositeModifier) to a proportion of samples with a given label.""" + dataset_to_modify = dataset.filter(lambda ex: ex["labels"] == label_to_modify) + remaining_dataset = dataset.filter(lambda ex: ex["labels"] != label_to_modify) + + n_examples = len(dataset_to_modify) + n_to_modify = round(n_examples * proportion) + indices = list(range(n_examples)) + selected_indices = set(self.rng.sample(indices, n_to_modify)) + + def modify_example(example, idx): + if idx in selected_indices: + new_text, new_label = modifier(example["text"], example["labels"]) + example["text"] = new_text + example["labels"] = new_label + return example + + modified_subset = dataset_to_modify.map(modify_example, with_indices=True) + return concatenate_datasets([modified_subset, remaining_dataset]) + + def build_spurious_dataset( + self, dataset, modifiers_config, label_to_modify, proportion + ): + """Construct a spurious dataset. + + Args: + dataset: Hugging Face Dataset object with "text" and "labels". + modifiers_config (list[Modifier] or Modifier): One or more modifiers to apply. + label_to_modify (int): Which label group to modify. + proportion (float): Proportion of examples within that label to modify (0-1). + + Returns: + Dataset: Modified dataset. + """ + if isinstance(modifiers_config, list): + modifier = CompositeModifier(modifiers_config) + else: + modifier = modifiers_config + + return self._apply_modifier_to_subset( + dataset, label_to_modify, modifier, proportion + ) diff --git a/stable_pretraining/data/utils.py b/stable_pretraining/data/utils.py index 851653f5..f5baf982 100644 --- a/stable_pretraining/data/utils.py +++ b/stable_pretraining/data/utils.py @@ -6,6 +6,8 @@ import itertools import math +import random +import calendar import warnings from collections.abc import Sequence from typing import Optional, Union, cast @@ -155,3 +157,43 @@ def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: out = x_expanded.gather(2, idx_expanded) return out.reshape(B * M, K, D) + + +def write_random_dates( + file_path, n=1000, year_range=(1100, 2600), seed=None, with_replacement=False +): + """Writes `n` random valid dates to a text file, one per line. + + Args: + file_path (str): Destination file path. + n (int): Number of random dates to generate. + year_range (tuple): Range of valid years (start, end). + seed (int): Optional random seed for reproducibility. + with_replacement (bool): Whether to allow for dates to repeat + """ + rng = random.Random(seed) + start_year, end_year = year_range + + # Generate all possible dates in range + all_dates = [] + for year in range(start_year, end_year + 1): + for month in range(1, 13): + _, max_day = calendar.monthrange(year, month) + for day in range(1, max_day + 1): + all_dates.append(f"{year}-{month:02d}-{day:02d}") + + total_possible = len(all_dates) + if not with_replacement and n > total_possible: + raise ValueError( + f"Cannot generate {n} unique dates — only {total_possible} unique dates possible " + f"for year range {year_range}." + ) + + if with_replacement: + chosen = [rng.choice(all_dates) for _ in range(n)] + else: + chosen = rng.sample(all_dates, k=n) + + with open(file_path, "w", encoding="utf-8") as f: + for d in chosen: + f.write(d + "\n") diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py new file mode 100644 index 00000000..49d6a9a5 --- /dev/null +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -0,0 +1,71 @@ +import pytest +import tempfile +import os +from stable_pretraining.data.utils import load_items_from_file, write_random_dates + + +@pytest.fixture +def temp_file(): + """Create a temporary file with predictable content.""" + content = "\n".join(f"item_{i}" for i in range(100)) + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: + f.write(content) + f.flush() + yield f.name + os.remove(f.name) + + +@pytest.mark.unit +def test_load_items_from_file_no_replacement(temp_file): + items = load_items_from_file(temp_file, seed=123, with_replacement=False) + assert len(items) == 100 + assert len(set(items)) == 100 # unique items only + + +@pytest.mark.unit +def test_load_items_from_file_with_replacement(temp_file): + items = load_items_from_file(temp_file, seed=123, with_replacement=True) + assert len(items) == 100 + # replacement allows duplicates + assert len(set(items)) <= 100 + + +@pytest.mark.unit +def test_load_items_reproducibility_with_same_seed(temp_file): + items1 = load_items_from_file(temp_file, seed=42, with_replacement=True) + items2 = load_items_from_file(temp_file, seed=42, with_replacement=True) + assert items1 == items2 + + +@pytest.mark.unit +def test_load_items_different_seed_produces_different_sequence(temp_file): + items1 = load_items_from_file(temp_file, seed=1, with_replacement=True) + items2 = load_items_from_file(temp_file, seed=2, with_replacement=True) + assert items1 != items2 # extremely unlikely to match + + +@pytest.mark.unit +def test_load_items_from_empty_file_raises(): + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: + path = f.name + with pytest.raises(ValueError): + load_items_from_file(path) + os.remove(path) + + +@pytest.mark.unit +def test_write_random_dates_creates_valid_file(): + """Ensure write_random_dates writes unique or repeatable dates.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "dates.txt") + write_random_dates( + path, + num_samples=50, + year_range=(2020, 2020), + seed=42, + with_replacement=False, + ) + with open(path, "r") as f: + dates = [line.strip() for line in f if line.strip()] + assert len(dates) == len(set(dates)) + assert all(date.startswith("2020-") for date in dates) diff --git a/stable_pretraining/data/spurious_corr/tests/test_html_injection.py b/stable_pretraining/tests/unit/test_html_injection.py similarity index 100% rename from stable_pretraining/data/spurious_corr/tests/test_html_injection.py rename to stable_pretraining/tests/unit/test_html_injection.py diff --git a/stable_pretraining/data/spurious_corr/tests/test_item_injection.py b/stable_pretraining/tests/unit/test_item_injection.py similarity index 100% rename from stable_pretraining/data/spurious_corr/tests/test_item_injection.py rename to stable_pretraining/tests/unit/test_item_injection.py diff --git a/stable_pretraining/tests/unit/test_write_random_dates.py b/stable_pretraining/tests/unit/test_write_random_dates.py new file mode 100644 index 00000000..9f53babf --- /dev/null +++ b/stable_pretraining/tests/unit/test_write_random_dates.py @@ -0,0 +1,80 @@ +import pytest +from pathlib import Path +from stable_pretraining.data.utils import write_random_dates + + +@pytest.mark.unit +def test_no_duplicates_with_replacement_false(tmp_path: Path): + """Ensure write_random_dates produces unique dates when with_replacement=False.""" + output_file = tmp_path / "dates.txt" + write_random_dates( + output_file, n=365, year_range=(2020, 2020), seed=123, with_replacement=False + ) + + dates = [line.strip() for line in open(output_file)] + assert len(dates) == len(set(dates)), "Duplicates found when with_replacement=False" + + +@pytest.mark.unit +def test_with_replacement_allows_duplicates(tmp_path: Path): + """Ensure write_random_dates allows duplicates when with_replacement=True.""" + output_file = tmp_path / "dates.txt" + write_random_dates( + output_file, n=1000, year_range=(2020, 2020), seed=42, with_replacement=True + ) + + dates = [line.strip() for line in open(output_file)] + # With replacement, duplicates should appear in large samples + assert len(set(dates)) < len(dates), ( + "No duplicates found when with_replacement=True" + ) + + +@pytest.mark.unit +def test_same_seed_produces_same_output(tmp_path: Path): + """Ensure deterministic output for same seed.""" + f1 = tmp_path / "dates1.txt" + f2 = tmp_path / "dates2.txt" + + write_random_dates( + f1, n=100, year_range=(1900, 1905), seed=42, with_replacement=True + ) + write_random_dates( + f2, n=100, year_range=(1900, 1905), seed=42, with_replacement=True + ) + + d1 = open(f1).read().splitlines() + d2 = open(f2).read().splitlines() + + assert d1 == d2, "Outputs differ for same seed" + + +@pytest.mark.unit +def test_different_seed_produces_different_output(tmp_path: Path): + """Ensure non-deterministic output for different seeds.""" + f1 = tmp_path / "dates1.txt" + f2 = tmp_path / "dates2.txt" + + write_random_dates( + f1, n=100, year_range=(1900, 1905), seed=42, with_replacement=True + ) + write_random_dates( + f2, n=100, year_range=(1900, 1905), seed=99, with_replacement=True + ) + + d1 = open(f1).read().splitlines() + d2 = open(f2).read().splitlines() + + assert d1 != d2, "Outputs identical for different seeds" + + +@pytest.mark.unit +def test_number_of_lines_written(tmp_path: Path): + """Ensure exactly n lines are written.""" + output_file = tmp_path / "dates.txt" + n = 50 + write_random_dates( + output_file, n=n, year_range=(2020, 2020), seed=0, with_replacement=False + ) + lines = open(output_file).read().splitlines() + assert len(lines) == n, f"Expected {n} lines, got {len(lines)}" From 9e52fb2110859ae9fbe92c1adc0b22c990d1bf63 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Wed, 22 Oct 2025 15:45:22 -0400 Subject: [PATCH 12/27] updating tests --- .../tests/unit/test_file_based_sampling.py | 88 ++++++++----------- 1 file changed, 35 insertions(+), 53 deletions(-) diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index 49d6a9a5..4c20aeec 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -1,71 +1,53 @@ import pytest import tempfile import os -from stable_pretraining.data.utils import load_items_from_file, write_random_dates - - -@pytest.fixture -def temp_file(): - """Create a temporary file with predictable content.""" - content = "\n".join(f"item_{i}" for i in range(100)) - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: - f.write(content) - f.flush() - yield f.name - os.remove(f.name) +from stable_pretraining.data.utils import write_random_dates +from stable_pretraining.data.spurious_corr.modifiers import ItemInjection @pytest.mark.unit -def test_load_items_from_file_no_replacement(temp_file): - items = load_items_from_file(temp_file, seed=123, with_replacement=False) - assert len(items) == 100 - assert len(set(items)) == 100 # unique items only +def test_write_random_dates_creates_file(): + with tempfile.TemporaryDirectory() as tmpdir: + out_file = os.path.join(tmpdir, "dates.txt") + write_random_dates(out_file, n=10, seed=42, with_replacement=False) + assert os.path.exists(out_file) + with open(out_file, "r") as f: + lines = [line.strip() for line in f if line.strip()] -@pytest.mark.unit -def test_load_items_from_file_with_replacement(temp_file): - items = load_items_from_file(temp_file, seed=123, with_replacement=True) - assert len(items) == 100 - # replacement allows duplicates - assert len(set(items)) <= 100 + assert len(lines) == 10 + assert all("-" in d for d in lines) # looks like YYYY-MM-DD @pytest.mark.unit -def test_load_items_reproducibility_with_same_seed(temp_file): - items1 = load_items_from_file(temp_file, seed=42, with_replacement=True) - items2 = load_items_from_file(temp_file, seed=42, with_replacement=True) - assert items1 == items2 +def test_item_injection_reads_from_file_and_injects_correctly(): + with tempfile.TemporaryDirectory() as tmpdir: + # Write fake items + file_path = os.path.join(tmpdir, "items.txt") + with open(file_path, "w") as f: + f.write("A\nB\nC\n") + # Create modifier + modifier = ItemInjection(file_path=file_path, token="TEST") -@pytest.mark.unit -def test_load_items_different_seed_produces_different_sequence(temp_file): - items1 = load_items_from_file(temp_file, seed=1, with_replacement=True) - items2 = load_items_from_file(temp_file, seed=2, with_replacement=True) - assert items1 != items2 # extremely unlikely to match + text, label = modifier("original text", 0) + # The exact assertion depends on how your ItemInjection modifies the text. + # Here's a generic check: + assert isinstance(text, str) + assert "TEST" in text or any(x in text for x in ["A", "B", "C"]) @pytest.mark.unit -def test_load_items_from_empty_file_raises(): - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: - path = f.name - with pytest.raises(ValueError): - load_items_from_file(path) - os.remove(path) +def test_item_injection_is_deterministic_with_seed(): + with tempfile.TemporaryDirectory() as tmpdir: + file_path = os.path.join(tmpdir, "items.txt") + with open(file_path, "w") as f: + f.write("X\nY\nZ\n") + m1 = ItemInjection(file_path=file_path, token="SPURIOUS", seed=123) + m2 = ItemInjection(file_path=file_path, token="SPURIOUS", seed=123) -@pytest.mark.unit -def test_write_random_dates_creates_valid_file(): - """Ensure write_random_dates writes unique or repeatable dates.""" - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "dates.txt") - write_random_dates( - path, - num_samples=50, - year_range=(2020, 2020), - seed=42, - with_replacement=False, - ) - with open(path, "r") as f: - dates = [line.strip() for line in f if line.strip()] - assert len(dates) == len(set(dates)) - assert all(date.startswith("2020-") for date in dates) + out1 = [m1("base text", 1)[0] for _ in range(10)] + out2 = [m2("base text", 1)[0] for _ in range(10)] + + assert out1 == out2 From e66b27612d5d98096bda037255b955779271769c Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Wed, 22 Oct 2025 15:51:25 -0400 Subject: [PATCH 13/27] updating tests --- stable_pretraining/tests/unit/test_file_based_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index 4c20aeec..65690d51 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -2,7 +2,7 @@ import tempfile import os from stable_pretraining.data.utils import write_random_dates -from stable_pretraining.data.spurious_corr.modifiers import ItemInjection +from stable_pretraining.data.spurious_corr.transforms import ItemInjection @pytest.mark.unit From 3d3c9e5df44bbd1be53dcc889ae5d5f28980d951 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Wed, 22 Oct 2025 15:51:57 -0400 Subject: [PATCH 14/27] updating tests --- stable_pretraining/tests/unit/test_file_based_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index 65690d51..6b8d94f3 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -2,7 +2,7 @@ import tempfile import os from stable_pretraining.data.utils import write_random_dates -from stable_pretraining.data.spurious_corr.transforms import ItemInjection +from stable_pretraining.data.transforms import ItemInjection @pytest.mark.unit From d74bdddbcb3f10c4e553d2d78c929125fb806b8a Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Wed, 22 Oct 2025 15:59:02 -0400 Subject: [PATCH 15/27] updating tests --- stable_pretraining/tests/unit/test_file_based_sampling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index 6b8d94f3..673b5d10 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -28,7 +28,7 @@ def test_item_injection_reads_from_file_and_injects_correctly(): f.write("A\nB\nC\n") # Create modifier - modifier = ItemInjection(file_path=file_path, token="TEST") + modifier = ItemInjection.from_file(file_path, seed=42) text, label = modifier("original text", 0) # The exact assertion depends on how your ItemInjection modifies the text. @@ -44,8 +44,8 @@ def test_item_injection_is_deterministic_with_seed(): with open(file_path, "w") as f: f.write("X\nY\nZ\n") - m1 = ItemInjection(file_path=file_path, token="SPURIOUS", seed=123) - m2 = ItemInjection(file_path=file_path, token="SPURIOUS", seed=123) + m1 = ItemInjection.from_file(file_path, seed=42) + m2 = ItemInjection.from_file(file_path, seed=42) out1 = [m1("base text", 1)[0] for _ in range(10)] out2 = [m2("base text", 1)[0] for _ in range(10)] From 35013db8c84c8f5d8eed178671619511f26e334c Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 11:24:10 -0400 Subject: [PATCH 16/27] further refactored code to make them all transformations --- stable_pretraining/data/transforms.py | 653 +++++++----------- .../tests/unit/test_file_based_sampling.py | 57 +- .../tests/unit/test_html_injection.py | 207 +----- .../tests/unit/test_item_injection.py | 134 ++-- 4 files changed, 423 insertions(+), 628 deletions(-) diff --git a/stable_pretraining/data/transforms.py b/stable_pretraining/data/transforms.py index 8c98841b..9ad910b3 100644 --- a/stable_pretraining/data/transforms.py +++ b/stable_pretraining/data/transforms.py @@ -20,395 +20,6 @@ from stable_pretraining.data.masking import multi_block_mask -# ============================================================ -# ===================== TEXT MODIFIERS ======================= -# ============================================================ - - -class Modifier: - """Base class for applying modifications/corruptions to text-label pairs. - - Subclasses must implement the __call__ method to define specific transformations. - - Example: - class MyModifier(Modifier): - def __call__(self, text: str, label: Any) -> tuple[str, Any]: - # custom transformation here - return transformed_text, transformed_label - """ - - def __call__(self, text: str, label): - """Apply the transformation to a single text-label pair. - - Args: - text (str): The input text to transform. - label: The associated label. - - Returns: - tuple: (transformed_text, transformed_label) - """ - raise NotImplementedError("Subclasses must implement __call__") - - -class CompositeModifier: - """CompositeModifier chains multiple Modifier instances together. - - Each modifier from the list is applied sequentially to the text. This enables - the combination of various transformations or injections into one composite operation. - """ - - def __init__(self, modifiers: list): - """Initialize a CompositeModifier instance. - - Args: - modifiers (list): A list of modifier instances (subclasses of Modifier) - to be applied sequentially. - """ - self.modifiers = modifiers - - def __call__(self, text: str, label): - """Apply all modifiers in sequence to the given (text, label). - - Args: - text (str): The input text. - label: The associated label. - - Returns: - tuple: The modified (text, label) pair after all transformations. - """ - for modifier in self.modifiers: - text, label = modifier(text, label) - return text, label - - -class ItemInjection(Modifier): - """A Modifier that injects items into text. - - This class supports creation via three different approaches: - - from_list: Using a predefined list of injection items. - - from_file: Reading injection items from a file. - - from_function: Using a custom function to generate injections. - """ - - def __init__( - self, - injection_source, - location: str = "random", - token_proportion: float = 0.1, - seed=None, - _rng=None, - ): - """Initialize an ItemInjection instance. - - Args: - injection_source (callable): A function that returns an injection token. - location (str): Where to inject the token ("beginning", "random", "end"). - token_proportion (float): Proportion of tokens in the text to be affected. - seed (int, optional): Seed for reproducibility. - """ - assert callable(injection_source), "injection_source must be callable" - self.injection_source = injection_source - self.location = location - self.token_proportion = token_proportion - self.rng = _rng or random.Random(seed) - - assert 0 <= token_proportion <= 1, "token_proportion must be between 0 and 1" - assert location in {"beginning", "random", "end"}, ( - "location must be 'beginning', 'random', or 'end'" - ) - - def __call__(self, text: str, label): - """Inject tokens into the text at specified locations. - - Args: - text (str): The input text to modify. - label: The original label (unchanged). - - Returns: - tuple: The modified text and the original label. - """ - words = text.split() - num_tokens = len(words) - - # Ensure at least one token is injected - num_to_inject = max(1, int(num_tokens * self.token_proportion)) - - injections = [self.injection_source() for _ in range(num_to_inject)] - - if self.location == "beginning": - words = injections + words - elif self.location == "end": - words = words + injections - elif self.location == "random": - for injection in injections: - pos = self.rng.randint(0, len(words)) - words.insert(pos, injection) - - return " ".join(words), label # return modified text and unchanged label - - @classmethod - def from_list( - cls, - items: list, - location: str = "random", - token_proportion: float = 0.1, - seed=None, - ): - """Create an ItemInjection instance using a predefined list of tokens. - - Args: - items (list): List of token strings to choose from. - location (str): Where to inject tokens ("beginning", "random", "end"). - token_proportion (float): Proportion of text tokens to be affected. - seed (int, optional): Seed for reproducibility. - - Returns: - ItemInjection: Configured instance. - """ - rng = random.Random(seed) - - def injection_source(): - return rng.choice(items) - - return cls( - injection_source, - location=location, - token_proportion=token_proportion, - seed=seed, - _rng=rng, - ) - - @classmethod - def from_file( - cls, - file_path: str, - location: str = "random", - token_proportion: float = 0.1, - seed=None, - ): - """Create an ItemInjection instance using tokens read from a file. - - Each non-empty line becomes a potential injection item. - - Args: - file_path (str): Path to the file with one token per line. - location (str): Where to inject tokens. - token_proportion (float): Proportion of tokens to inject. - seed (int, optional): Seed for reproducibility. - - Returns: - ItemInjection: Configured instance. - """ - with open(file_path, "r", encoding="utf-8") as file: - items = [line.strip() for line in file if line.strip()] - - rng = random.Random(seed) - - def injection_source(): - return rng.choice(items) - - return cls( - injection_source, - location=location, - token_proportion=token_proportion, - _rng=rng, - ) - - @classmethod - def from_function( - cls, - injection_func, - location: str = "random", - token_proportion: float = 0.1, - seed=None, - ): - """Create an ItemInjection instance using a custom function to generate injections. - - Args: - injection_func (callable): Function that returns a new injection token each time. - location (str): Where to inject tokens. - token_proportion (float): Proportion of text to inject into. - seed (int, optional): Seed for reproducibility (used only for insertion position). - - Returns: - ItemInjection: Configured instance. - """ - assert callable(injection_func), "injection_func must be callable" - return cls( - injection_func, - location=location, - token_proportion=token_proportion, - seed=seed, - ) - - -class HTMLInjection(Modifier): - """A Modifier that injects html into text. - - This class supports creation via two different approaches: - - from_list: Using a predefined list of injection items. - - from_file: Reading injection items from a file. - """ - - def __init__( - self, - file_path: str, - location: str = "random", - level: int = None, - token_proportion: float = None, - seed=None, - ): - with open(file_path, "r", encoding="utf-8") as f: - self.tags = [line.strip() for line in f if line.strip()] - self.location = location - self.level = level - self.token_proportion = token_proportion - self.rng = random.Random(seed) - - if token_proportion is not None: - assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" - - @classmethod - def from_file( - cls, - file_path: str, - location: str = "random", - level: int = None, - token_proportion: float = None, - seed=None, - ): - return cls( - file_path, - location=location, - level=level, - token_proportion=token_proportion, - seed=seed, - ) - - @classmethod - def from_list( - cls, - tags: list, - location: str = "random", - level: int = None, - token_proportion: float = None, - seed=None, - ): - instance = cls.__new__(cls) - instance.tags = tags - instance.location = location - instance.level = level - instance.token_proportion = token_proportion - instance.rng = random.Random(seed) - - if token_proportion is not None: - assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" - - return instance - - def _choose_tag(self): - """Randomly choose a tag from the loaded list. - - Returns: - tuple: (opening_tag, closing_tag or None) - """ - line = self.rng.choice(self.tags) - parts = line.split() - if len(parts) >= 2: - return parts[0], parts[1] - else: - return parts[0], None - - def _inject_into_tokens(self, tokens, location): - tokens = tokens[:] - n = len(tokens) - - if self.token_proportion is None: - opening, closing = self._choose_tag() - return self._inject_with_tags(tokens, opening, closing, location) - - # Otherwise, inject up to token_proportion of total tokens - num_insertions = max(1, int(n * self.token_proportion)) - for _ in range(num_insertions): - opening, closing = self._choose_tag() - tokens = self._inject_with_tags(tokens, opening, closing, location) - return tokens - - def _inject_with_tags(self, tokens, opening, closing, location): - if location == "beginning": - new_tokens = [opening] + tokens - if closing: - pos = self.rng.randint(1, len(new_tokens)) - new_tokens.insert(pos, closing) - return new_tokens - - elif location == "end": - new_tokens = tokens[:] - pos = self.rng.randint(0, len(new_tokens)) - new_tokens.insert(pos, opening) - if closing: - new_tokens.append(closing) - return new_tokens - - elif location == "random": - new_tokens = tokens[:] - pos_open = self.rng.randint(0, len(new_tokens)) - new_tokens.insert(pos_open, opening) - if closing: - pos_close = self.rng.randint(pos_open + 1, len(new_tokens)) - new_tokens.insert(pos_close, closing) - return new_tokens - - return tokens - - def _inject(self, text, location): - tokens = text.split() - new_tokens = self._inject_into_tokens(tokens, location) - return " ".join(new_tokens) - - def _find_level_span(self, text, level): - """Find the first span inside the desired HTML nesting level. - - Args: - text (str): Input HTML text. - level (int): Desired nesting level. - - Returns: - tuple or None: (start, end) of the content region, or None if not found. - """ - tag_regex = re.compile(r"]*>") - stack = [] - for match in tag_regex.finditer(text): - tag_str = match.group(0) - tag_name = match.group(1) - if not tag_str.startswith(" Dict[str, torch.Tensor]: + if "idx" not in x: + x["idx"] = self._counter + self._counter += 1 + + return x + + +class ClassConditionalInjector(Transform): + """Applies transformations conditionally based on sample label. + + Args: + transformation (Transform): Transform to apply to the image. + label_key (str): Key for label in the sample dict. + target_labels (Union[int, list[int]]): Which labels to modify. + proportion (float): Fraction of samples with matching labels to modify (0-1). + total_samples (int, optional): Dataset size (for deterministic mask). + seed (int): Seed for randomization to determine which samples transformation is applied to + """ + + def __init__( + self, + transformation: Transform, + label_key: str = "label", + target_labels: Union[int, list[int]] = 0, + proportion: float = 0.5, + total_samples: Optional[int] = None, + seed: int = 42, + ): + super().__init__() + self.transformation = transformation + self.label_key = label_key + self.target_labels = ( + [target_labels] if isinstance(target_labels, int) else target_labels + ) + self.proportion = proportion + self.total_samples = total_samples + self.seed = seed + + # Precompute deterministic mask if dataset size known + if total_samples is not None: + num_to_transform = int(total_samples * proportion) + rng = torch.Generator().manual_seed(seed) + self.indices_to_transform = set( + torch.randperm(total_samples, generator=rng)[:num_to_transform].tolist() + ) + else: + self.indices_to_transform = None + + def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + label = self.nested_get(x, self.label_key) + + # Determine if we apply the transformation + should_transform = False + idx = self.nested_get(x, "idx") + if label in self.target_labels: + if self.indices_to_transform is not None: + should_transform = idx in self.indices_to_transform + else: + should_transform = random.random() < self.proportion + + if should_transform: + x = self.transformation(x) + + return x + + +class SpuriousTextInjection(Transform): + """Injects spurious tokens into text for specific target classes. + + Args: + text_key (str): The name of the key representing the text in the dataset + class_key (str): The name of the key representing the label in the dataset + class_target (int): The label to have the spurious correlation injected into + file_path (str): The path of the file to inject spurious correlations from + p (float): The proportion of samples to inject the spurious token into + location (str): The location of the text to inject the spurious token(s) into + token_proportion (float): The proportion of the original tokens available in the dataset to inject as spurious tokens + (used to determine the number injected per sample) + seed (int): Seed for reproducibility + """ + + def __init__( + self, + text_key: str, + file_path: str, + location: str = "random", + token_proportion: float = 0.1, + seed: int = None, + ): + self.text_key = text_key + self.location = location + self.token_proportion = token_proportion + self.rng = random.Random(seed) + + with open(file_path, "r", encoding="utf-8") as f: + self.items = [line.strip() for line in f if line.strip()] + + assert self.items, f"No valid lines found in {file_path}" + assert 0 <= self.token_proportion <= 1, "token_proportion must be in [0, 1]" + assert self.location in {"beginning", "random", "end"} + + def _inject(self, text: str) -> str: + words = text.split() + num_tokens = len(words) + num_to_inject = max(1, int(num_tokens * self.token_proportion)) + injections = [self.rng.choice(self.items) for _ in range(num_to_inject)] + + if self.location == "beginning": + words = injections + words + elif self.location == "end": + words = words + injections + elif self.location == "random": + for inj in injections: + pos = self.rng.randint(0, len(words)) + words.insert(pos, inj) + return " ".join(words) + + def __call__(self, x: dict) -> dict: + text = x[self.text_key] + x[self.text_key] = self._inject(text) + return x + + +class HTMLInjection(Transform): + """Injects HTML-like tags into text fields (deterministically if 'idx' present). + + This transform adds artificial HTML tokens to text data, optionally at a specific + HTML nesting level or a random position. Supports deterministic per-sample + injection when used with AddSampleIdx and ClassConditionalInjector. + + Args: + text_key (str): Key for the text field in the dataset sample. + file_path (str): Path to file containing HTML tags (each line = one tag or tag pair). + location (str): Where to inject tags ("beginning", "end", or "random"). + level (int, optional): Target HTML nesting level to inject within. + token_proportion (float, optional): The proportion of the original tokens available in the dataset + to inject as spurious tokens (used to determine the number injected per sample) + seed (int, optional): Random seed for reproducibility. + """ + + def __init__( + self, + text_key: str, + file_path: str, + location: str = "random", + level: Optional[int] = None, + token_proportion: Optional[float] = None, + seed: Optional[int] = None, + ): + super().__init__() + self.text_key = text_key + self.location = location + self.level = level + self.token_proportion = token_proportion + self.base_seed = seed if seed is not None else 0 + self.rng = random.Random(seed) + + with open(file_path, "r", encoding="utf-8") as f: + self.tags = [line.strip() for line in f if line.strip()] + + assert self.tags, f"No valid tags found in {file_path}" + if token_proportion is not None: + assert 0 < token_proportion <= 1, "token_proportion must be between 0 and 1" + assert self.location in {"beginning", "end", "random"}, "invalid location" + + # ---------- Internal helpers ---------- + def _choose_tag(self, rng): + """Select an opening/closing tag pair.""" + line = rng.choice(self.tags) + parts = line.split() + if len(parts) >= 2: + return parts[0], parts[1] + else: + return parts[0], None + + def _inject_with_tags(self, tokens, opening, closing, location, rng): + """Inject the a single tag into the text.""" + new_tokens = tokens[:] + if location == "beginning": + new_tokens.insert(0, opening) + if closing: + pos = rng.randint(1, len(new_tokens)) + new_tokens.insert(pos, closing) + elif location == "end": + pos = rng.randint(0, len(new_tokens)) + new_tokens.insert(pos, opening) + if closing: + new_tokens.append(closing) + elif location == "random": + pos_open = rng.randint(0, len(new_tokens)) + new_tokens.insert(pos_open, opening) + if closing: + pos_close = rng.randint(pos_open + 1, len(new_tokens)) + new_tokens.insert(pos_close, closing) + return new_tokens + + def _inject(self, text, rng): + """Overall injection of all tags into the text.""" + tokens = text.split() + if not tokens: + return text + + if self.token_proportion is None: + opening, closing = self._choose_tag(rng) + tokens = self._inject_with_tags( + tokens, opening, closing, self.location, rng + ) + else: + n = len(tokens) + num_insertions = max(1, int(n * self.token_proportion)) + for _ in range(num_insertions): + opening, closing = self._choose_tag(rng) + tokens = self._inject_with_tags( + tokens, opening, closing, self.location, rng + ) + return " ".join(tokens) + + def _inject_at_level(self, text, level, rng): + """Inject tags inside a specific HTML nesting level.""" + tag_regex = re.compile(r"]*>") + stack = [] + for match in tag_regex.finditer(text): + tag_str = match.group(0) + tag_name = match.group(1) + if not tag_str.startswith(" Dict[str, Any]: + """Main call function for the transformation.""" + text = x[self.text_key] + + # Deterministic per-sample RNG if idx available + if "idx" in x: + seed = self.base_seed + int(x["idx"]) + rng = random.Random(seed) + # fallback (non-deterministic but seeded globally) + else: + rng = self.rng + + if self.level is None: + x[self.text_key] = self._inject(text, rng) + else: + x[self.text_key] = self._inject_at_level(text, self.level, rng) + + return x diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index 673b5d10..812c8760 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -2,7 +2,7 @@ import tempfile import os from stable_pretraining.data.utils import write_random_dates -from stable_pretraining.data.transforms import ItemInjection +from stable_pretraining.data.transforms import SpuriousTextInjection @pytest.mark.unit @@ -20,34 +20,55 @@ def test_write_random_dates_creates_file(): @pytest.mark.unit -def test_item_injection_reads_from_file_and_injects_correctly(): +def test_spurious_text_injection_reads_from_file_and_injects_correctly(): with tempfile.TemporaryDirectory() as tmpdir: - # Write fake items - file_path = os.path.join(tmpdir, "items.txt") + # Create fake spurious tokens file + file_path = os.path.join(tmpdir, "tokens.txt") with open(file_path, "w") as f: f.write("A\nB\nC\n") - # Create modifier - modifier = ItemInjection.from_file(file_path, seed=42) + # Create transform + transform = SpuriousTextInjection( + text_key="text", + source=file_path, + location="random", + token_proportion=0.5, + seed=42, + ) - text, label = modifier("original text", 0) - # The exact assertion depends on how your ItemInjection modifies the text. - # Here's a generic check: - assert isinstance(text, str) - assert "TEST" in text or any(x in text for x in ["A", "B", "C"]) + sample = {"text": "original text", "label": 0} + output = transform(sample) + + assert isinstance(output["text"], str) + assert output["text"] != "original text" + assert any(tok in output["text"] for tok in ["A", "B", "C"]) @pytest.mark.unit -def test_item_injection_is_deterministic_with_seed(): +def test_spurious_text_injection_is_deterministic_with_seed(): with tempfile.TemporaryDirectory() as tmpdir: - file_path = os.path.join(tmpdir, "items.txt") + file_path = os.path.join(tmpdir, "tokens.txt") with open(file_path, "w") as f: f.write("X\nY\nZ\n") - m1 = ItemInjection.from_file(file_path, seed=42) - m2 = ItemInjection.from_file(file_path, seed=42) + # Create two transforms with the same seed + t1 = SpuriousTextInjection( + text_key="text", + source=file_path, + location="end", + token_proportion=0.5, + seed=123, + ) + t2 = SpuriousTextInjection( + text_key="text", + source=file_path, + location="end", + token_proportion=0.5, + seed=123, + ) - out1 = [m1("base text", 1)[0] for _ in range(10)] - out2 = [m2("base text", 1)[0] for _ in range(10)] + sample = {"text": "base text", "label": 1} + outputs1 = [t1(sample)["text"] for _ in range(5)] + outputs2 = [t2(sample)["text"] for _ in range(5)] - assert out1 == out2 + assert outputs1 == outputs2, "Should produce identical results with same seed" diff --git a/stable_pretraining/tests/unit/test_html_injection.py b/stable_pretraining/tests/unit/test_html_injection.py index 22646e44..38e69487 100644 --- a/stable_pretraining/tests/unit/test_html_injection.py +++ b/stable_pretraining/tests/unit/test_html_injection.py @@ -3,201 +3,64 @@ @pytest.mark.unit -def test_html_injection_proportion(tmp_path): - # Create a dummy tag file with 3 full tag pairs +def test_html_injection_deterministic_same_idx(tmp_path): tag_path = tmp_path / "tags.txt" - tag_path.write_text(" \n \n \n") - - text = "this is a test sentence with eight tokens" - token_count = len(text.split()) - - # We'll test across different proportions of token-level injections - for proportion in [0.1, 0.25, 0.5, 0.75, 1.0]: - modifier = HTMLInjection.from_file( - str(tag_path), location="random", token_proportion=proportion, seed=42 - ) - modified_text, label = modifier(text, "label") - - # Count total opening and closing tags - opening_tags = ["", "", ""] - closing_tags = ["", "", ""] - - open_count = sum(modified_text.count(tag) for tag in opening_tags) - close_count = sum(modified_text.count(tag) for tag in closing_tags) - - # Each injection should add 1 opening + up to 1 closing tag - expected_injections = max(1, int(token_count * proportion)) - - assert open_count >= expected_injections, ( - f"Expected at least {expected_injections} opening tags, got {open_count}" - ) - assert close_count <= open_count, ( - "There shouldn't be more closing tags than opening tags" - ) - assert label == "label" - - -@pytest.mark.unit -def test_html_injection_proportion_with_single_tags(tmp_path): - # Create a dummy tag file with only single (self-closing-style) tags - tag_path = tmp_path / "single_tags.txt" - tag_path.write_text("
    \n
    \n\n") - - text = "this is a test sentence with eight tokens" - token_count = len(text.split()) - - for proportion in [0.1, 0.25, 0.5, 0.75, 1.0]: - modifier = HTMLInjection.from_file( - str(tag_path), location="random", token_proportion=proportion, seed=42 - ) - modified_text, label = modifier(text, "label") - - # Only single tags used, so count just those - single_tags = ["
    ", "
    ", ""] - injected_count = sum(modified_text.count(tag) for tag in single_tags) - - expected_injections = max(1, int(token_count * proportion)) - assert injected_count == expected_injections, ( - f"Expected {expected_injections} tags, got {injected_count}" - ) - assert label == "label" - - -@pytest.mark.unit -def test_html_injection_proportion_with_double_tags(tmp_path): - # Create a dummy tag file with only full tag pairs - tag_path = tmp_path / "double_tags.txt" - tag_path.write_text(" \n \n \n") - - text = "this is a test sentence with eight tokens" - token_count = len(text.split()) - - for proportion in [0.1, 0.25, 0.5, 0.75, 1.0]: - modifier = HTMLInjection.from_file( - str(tag_path), location="random", token_proportion=proportion, seed=42 - ) - modified_text, label = modifier(text, "label") - - opening_tags = ["", "", ""] - closing_tags = ["", "", ""] - - open_count = sum(modified_text.count(tag) for tag in opening_tags) - close_count = sum(modified_text.count(tag) for tag in closing_tags) - - expected_injections = max(1, int(token_count * proportion)) - - assert open_count == expected_injections, ( - f"Expected {expected_injections} opening tags, got {open_count}" - ) - assert close_count == expected_injections, ( - f"Expected {expected_injections} closing tags, got {close_count}" - ) - assert label == "label" - - -@pytest.mark.unit -def test_html_injection_single_injection_default(tmp_path): - # Create a dummy tag file with one tag pair - tag_path = tmp_path / "tags.txt" - tag_path.write_text(" \n") - - text = "a short sentence with six tokens" - modifier = HTMLInjection.from_file(str(tag_path), location="random", seed=42) + tag_path.write_text(" \n") - modified_text, label = modifier(text, "label") + # Create deterministic injection transform + modifier = HTMLInjection( + file_path=str(tag_path), location="random", token_proportion=0.5, seed=123 + ) - # Expect exactly one opening tag and at most one closing tag - opening_tag = "" - closing_tag = "" + # Two samples with the same idx should yield identical results + sample1 = {"text": "consistent sample text", "label": "lbl", "idx": 10} + sample2 = {"text": "consistent sample text", "label": "lbl", "idx": 10} - open_count = modified_text.count(opening_tag) - close_count = modified_text.count(closing_tag) + out1 = modifier(sample1) + out2 = modifier(sample2) - assert open_count == 1, f"Expected exactly one opening tag, got {open_count}" - assert close_count <= 1, f"Expected at most one closing tag, got {close_count}" - assert label == "label" + assert out1["text"] == out2["text"], "Same idx should produce identical injection" + assert out1["label"] == out2["label"] @pytest.mark.unit -def test_html_injection_location_beginning(tmp_path): +def test_html_injection_deterministic_different_idx(tmp_path): tag_path = tmp_path / "tags.txt" tag_path.write_text(" \n") - text = "sample sentence" - - modifier = HTMLInjection.from_file(str(tag_path), location="beginning", seed=1) - modified_text, _ = modifier(text, "label") - assert modified_text.startswith(""), "Opening tag should be at the beginning" - -@pytest.mark.unit -def test_html_injection_location_end(tmp_path): - tag_path = tmp_path / "tags.txt" - tag_path.write_text(" \n") - text = "another sample" - - modifier = HTMLInjection.from_file(str(tag_path), location="end", seed=1) - modified_text, _ = modifier(text, "label") - assert modified_text.endswith("
    ") or "" in modified_text, ( - "Tag should be appended at end" + modifier = HTMLInjection( + file_path=str(tag_path), location="random", token_proportion=0.5, seed=123 ) + # Two samples with different idx should yield different results + sample1 = {"text": "different sample text", "label": "lbl", "idx": 1} + sample2 = {"text": "different sample text", "label": "lbl", "idx": 2} -@pytest.mark.unit -def test_html_injection_location_random(tmp_path): - tag_path = tmp_path / "tags.txt" - tag_path.write_text(" \n") - text = "tokens in various spots" - - modifier = HTMLInjection.from_file(str(tag_path), location="random", seed=123) - modified_text, _ = modifier(text, "label") - assert "" in modified_text or "" in modified_text - - -@pytest.mark.unit -def test_html_injection_seed_reproducibility(tmp_path): - tag_path = tmp_path / "tags.txt" - tag_path.write_text(" \n") + out1 = modifier(sample1) + out2 = modifier(sample2) - text = "reproducibility is key" - mod1 = HTMLInjection.from_file( - str(tag_path), location="random", token_proportion=0.5, seed=42 - ) - mod2 = HTMLInjection.from_file( - str(tag_path), location="random", token_proportion=0.5, seed=42 + assert out1["text"] != out2["text"], ( + "Different idx should produce different injections" ) - out1, _ = mod1(text, "label") - out2, _ = mod2(text, "label") - assert out1 == out2 - @pytest.mark.unit -def test_html_injection_different_seeds(tmp_path): +def test_html_injection_deterministic_reproducibility_across_runs(tmp_path): tag_path = tmp_path / "tags.txt" - tag_path.write_text(" \n") - text = "inject differently based on seed" + tag_path.write_text("

    \n") - mod1 = HTMLInjection.from_file( - str(tag_path), location="random", token_proportion=0.5, seed=1 + sample = {"text": "check reproducibility", "label": "lbl", "idx": 42} + + modifier1 = HTMLInjection( + file_path=str(tag_path), location="random", token_proportion=0.5, seed=777 ) - mod2 = HTMLInjection.from_file( - str(tag_path), location="random", token_proportion=0.5, seed=2 + modifier2 = HTMLInjection( + file_path=str(tag_path), location="random", token_proportion=0.5, seed=777 ) - out1, _ = mod1(text, "label") - out2, _ = mod2(text, "label") - assert out1 != out2, "Different seeds should yield different outputs" - - -@pytest.mark.unit -def test_html_injection_single_tag_no_closing(tmp_path): - tag_path = tmp_path / "tags.txt" - tag_path.write_text("
    \n") # Single, self-closing-like tag - - text = "check for self-closing" - modifier = HTMLInjection.from_file(str(tag_path), location="end", seed=99) - modified_text, _ = modifier(text, "label") + out1 = modifier1(sample) + out2 = modifier2(sample) - assert "
    " in modified_text and ""], location="beginning", seed=42) - modified_text, _ = modifier(text, "label") - assert modified_text.startswith(""), "Injection should be at the beginning" + out1 = transform(sample1) + out2 = transform(sample2) + + assert out1["text"] != out2["text"], "Different idx should yield different outputs" @pytest.mark.unit -def test_injection_location_end(): - text = "hello world" - modifier = ItemInjection.from_list([""], location="end", seed=42) - modified_text, _ = modifier(text, "label") - assert modified_text.endswith(""), "Injection should be at the end" +def test_spurious_text_injection_reproducibility_across_runs(tmp_path): + src_path = tmp_path / "tokens.txt" + src_path.write_text("A\nB\nC\n") + text = "reproducibility test for spurious injection" + sample = {"text": text, "label": "C", "idx": 42} -@pytest.mark.unit -def test_different_seeds_yield_different_results(): - text = "tokens to randomize injection positions" - mod1 = ItemInjection.from_list( - [""], token_proportion=0.5, location="random", seed=1 + transform1 = SpuriousTextInjection( + text_key="text", + source=str(src_path), + location="end", + token_proportion=0.25, + seed=999, ) - mod2 = ItemInjection.from_list( - [""], token_proportion=0.5, location="random", seed=2 + transform2 = SpuriousTextInjection( + text_key="text", + source=str(src_path), + location="end", + token_proportion=0.25, + seed=999, ) - text1, _ = mod1(text, "label") - text2, _ = mod2(text, "label") + out1 = transform1(sample) + out2 = transform2(sample) - assert text1 != text2, "Different seeds should yield different injection positions" + assert out1["text"] == out2["text"], ( + "Same seed and idx should yield identical injection" + ) + + +@pytest.mark.unit +def test_spurious_text_injection_with_addsampleidx(tmp_path): + src_path = tmp_path / "spurious.txt" + src_path.write_text("NOISE\nTAG\n") + + add_idx = AddSampleIdx() + transform = SpuriousTextInjection( + text_key="text", + source=str(src_path), + location="beginning", + token_proportion=0.3, + seed=777, + ) + + sample = {"text": "verify deterministic pipeline", "label": "Y"} + sample = add_idx(sample) + out = transform(sample) + + assert "idx" in sample, "AddSampleIdx should add an 'idx' field" + assert any(t in out["text"] for t in ["NOISE", "TAG"]), ( + "Should inject spurious token" + ) From eeabb83982394808e911bbe310bc0ba9d7da60f9 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 11:33:18 -0400 Subject: [PATCH 17/27] changed some parameter names --- .../tests/unit/test_file_based_sampling.py | 6 ++--- .../tests/unit/test_html_injection.py | 24 +++++++++++++++---- .../tests/unit/test_item_injection.py | 10 ++++---- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index 812c8760..c9b1ddfc 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -30,7 +30,7 @@ def test_spurious_text_injection_reads_from_file_and_injects_correctly(): # Create transform transform = SpuriousTextInjection( text_key="text", - source=file_path, + file_path=file_path, location="random", token_proportion=0.5, seed=42, @@ -54,14 +54,14 @@ def test_spurious_text_injection_is_deterministic_with_seed(): # Create two transforms with the same seed t1 = SpuriousTextInjection( text_key="text", - source=file_path, + file_path=file_path, location="end", token_proportion=0.5, seed=123, ) t2 = SpuriousTextInjection( text_key="text", - source=file_path, + file_path=file_path, location="end", token_proportion=0.5, seed=123, diff --git a/stable_pretraining/tests/unit/test_html_injection.py b/stable_pretraining/tests/unit/test_html_injection.py index 38e69487..1700fc3f 100644 --- a/stable_pretraining/tests/unit/test_html_injection.py +++ b/stable_pretraining/tests/unit/test_html_injection.py @@ -9,7 +9,11 @@ def test_html_injection_deterministic_same_idx(tmp_path): # Create deterministic injection transform modifier = HTMLInjection( - file_path=str(tag_path), location="random", token_proportion=0.5, seed=123 + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=123, ) # Two samples with the same idx should yield identical results @@ -29,7 +33,11 @@ def test_html_injection_deterministic_different_idx(tmp_path): tag_path.write_text(" \n") modifier = HTMLInjection( - file_path=str(tag_path), location="random", token_proportion=0.5, seed=123 + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=123, ) # Two samples with different idx should yield different results @@ -52,10 +60,18 @@ def test_html_injection_deterministic_reproducibility_across_runs(tmp_path): sample = {"text": "check reproducibility", "label": "lbl", "idx": 42} modifier1 = HTMLInjection( - file_path=str(tag_path), location="random", token_proportion=0.5, seed=777 + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=777, ) modifier2 = HTMLInjection( - file_path=str(tag_path), location="random", token_proportion=0.5, seed=777 + file_path=str(tag_path), + text_key="text", + location="random", + token_proportion=0.5, + seed=777, ) out1 = modifier1(sample) diff --git a/stable_pretraining/tests/unit/test_item_injection.py b/stable_pretraining/tests/unit/test_item_injection.py index face1246..82e4e819 100644 --- a/stable_pretraining/tests/unit/test_item_injection.py +++ b/stable_pretraining/tests/unit/test_item_injection.py @@ -11,7 +11,7 @@ def test_spurious_text_injection_deterministic_same_idx(tmp_path): text = "deterministic spurious injection test" transform = SpuriousTextInjection( text_key="text", - source=str(src_path), + file_path=str(src_path), location="random", token_proportion=0.5, seed=123, @@ -35,7 +35,7 @@ def test_spurious_text_injection_deterministic_different_idx(tmp_path): text = "check for idx-dependent difference" transform = SpuriousTextInjection( text_key="text", - source=str(src_path), + file_path=str(src_path), location="random", token_proportion=0.5, seed=321, @@ -60,14 +60,14 @@ def test_spurious_text_injection_reproducibility_across_runs(tmp_path): transform1 = SpuriousTextInjection( text_key="text", - source=str(src_path), + file_path=str(src_path), location="end", token_proportion=0.25, seed=999, ) transform2 = SpuriousTextInjection( text_key="text", - source=str(src_path), + file_path=str(src_path), location="end", token_proportion=0.25, seed=999, @@ -89,7 +89,7 @@ def test_spurious_text_injection_with_addsampleidx(tmp_path): add_idx = AddSampleIdx() transform = SpuriousTextInjection( text_key="text", - source=str(src_path), + file_path=str(src_path), location="beginning", token_proportion=0.3, seed=777, From 65fdf42578d0d7598ae9638dbb0c78d3cbd2c8eb Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 11:53:12 -0400 Subject: [PATCH 18/27] trying to fix one of two errors --- .../tests/unit/test_item_injection.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/stable_pretraining/tests/unit/test_item_injection.py b/stable_pretraining/tests/unit/test_item_injection.py index 82e4e819..c66c0fc4 100644 --- a/stable_pretraining/tests/unit/test_item_injection.py +++ b/stable_pretraining/tests/unit/test_item_injection.py @@ -8,8 +8,11 @@ def test_spurious_text_injection_deterministic_same_idx(tmp_path): src_path = tmp_path / "spurious.txt" src_path.write_text("RED\nGREEN\nBLUE\n") + src_path2 = tmp_path / "spurious2.txt" + src_path2.write_text("RED\nGREEN\nBLUE\n") + text = "deterministic spurious injection test" - transform = SpuriousTextInjection( + transform1 = SpuriousTextInjection( text_key="text", file_path=str(src_path), location="random", @@ -17,11 +20,19 @@ def test_spurious_text_injection_deterministic_same_idx(tmp_path): seed=123, ) + transform2 = SpuriousTextInjection( + text_key="text", + file_path=str(src_path2), + location="random", + token_proportion=0.5, + seed=123, + ) + sample1 = {"text": text, "label": "A", "idx": 5} sample2 = {"text": text, "label": "A", "idx": 5} - out1 = transform(sample1) - out2 = transform(sample2) + out1 = transform1(sample1) + out2 = transform2(sample2) assert out1["text"] == out2["text"], "Same idx should produce identical injection" assert out1["label"] == out2["label"] From f6b0ee87c49cd0611262b6b0bcd794db2d495717 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 14:58:22 -0400 Subject: [PATCH 19/27] updated the test --- .../tests/unit/test_file_based_sampling.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index c9b1ddfc..f0dd46a0 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -46,12 +46,12 @@ def test_spurious_text_injection_reads_from_file_and_injects_correctly(): @pytest.mark.unit def test_spurious_text_injection_is_deterministic_with_seed(): + # Create two transforms with the same seed with tempfile.TemporaryDirectory() as tmpdir: file_path = os.path.join(tmpdir, "tokens.txt") with open(file_path, "w") as f: f.write("X\nY\nZ\n") - # Create two transforms with the same seed t1 = SpuriousTextInjection( text_key="text", file_path=file_path, @@ -59,6 +59,11 @@ def test_spurious_text_injection_is_deterministic_with_seed(): token_proportion=0.5, seed=123, ) + with tempfile.TemporaryDirectory() as tmpdir: + file_path = os.path.join(tmpdir, "tokens.txt") + with open(file_path, "w") as f2: + f2.write("X\nY\nZ\n") + t2 = SpuriousTextInjection( text_key="text", file_path=file_path, @@ -67,8 +72,8 @@ def test_spurious_text_injection_is_deterministic_with_seed(): seed=123, ) - sample = {"text": "base text", "label": 1} - outputs1 = [t1(sample)["text"] for _ in range(5)] - outputs2 = [t2(sample)["text"] for _ in range(5)] + sample = {"text": "base text", "label": 1} + outputs1 = [t1(sample)["text"] for _ in range(5)] + outputs2 = [t2(sample)["text"] for _ in range(5)] - assert outputs1 == outputs2, "Should produce identical results with same seed" + assert outputs1 == outputs2, "Should produce identical results with same seed" From dc658e972ade5747bb2b1f5177e0006eca88d59e Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 15:09:10 -0400 Subject: [PATCH 20/27] updating deterministic injectoin --- stable_pretraining/data/transforms.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/stable_pretraining/data/transforms.py b/stable_pretraining/data/transforms.py index 9ad910b3..f0d651c4 100644 --- a/stable_pretraining/data/transforms.py +++ b/stable_pretraining/data/transforms.py @@ -1072,8 +1072,7 @@ def __init__( self.text_key = text_key self.location = location self.token_proportion = token_proportion - self.rng = random.Random(seed) - + self.base_seed = seed with open(file_path, "r", encoding="utf-8") as f: self.items = [line.strip() for line in f if line.strip()] @@ -1081,11 +1080,11 @@ def __init__( assert 0 <= self.token_proportion <= 1, "token_proportion must be in [0, 1]" assert self.location in {"beginning", "random", "end"} - def _inject(self, text: str) -> str: + def _inject(self, text: str, rng: random.Random) -> str: words = text.split() num_tokens = len(words) num_to_inject = max(1, int(num_tokens * self.token_proportion)) - injections = [self.rng.choice(self.items) for _ in range(num_to_inject)] + injections = [rng.choice(self.items) for _ in range(num_to_inject)] if self.location == "beginning": words = injections + words @@ -1093,13 +1092,21 @@ def _inject(self, text: str) -> str: words = words + injections elif self.location == "random": for inj in injections: - pos = self.rng.randint(0, len(words)) + pos = rng.randint(0, len(words)) words.insert(pos, inj) return " ".join(words) def __call__(self, x: dict) -> dict: text = x[self.text_key] - x[self.text_key] = self._inject(text) + + # Deterministic RNG per call + if self.base_seed is not None: + # Derive a deterministic RNG for this call + rng = random.Random(self.base_seed) + else: + rng = random.Random() + + x[self.text_key] = self._inject(text, rng) return x From 039896cff92714ca87322f61326f2b2d0d0b379a Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 15:24:08 -0400 Subject: [PATCH 21/27] minor update --- .../tests/unit/test_file_based_sampling.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index f0dd46a0..67171dc9 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -46,12 +46,13 @@ def test_spurious_text_injection_reads_from_file_and_injects_correctly(): @pytest.mark.unit def test_spurious_text_injection_is_deterministic_with_seed(): - # Create two transforms with the same seed + # Create one file with spurious tokens with tempfile.TemporaryDirectory() as tmpdir: file_path = os.path.join(tmpdir, "tokens.txt") with open(file_path, "w") as f: f.write("X\nY\nZ\n") + # Create two transforms reading from the same file t1 = SpuriousTextInjection( text_key="text", file_path=file_path, @@ -59,11 +60,6 @@ def test_spurious_text_injection_is_deterministic_with_seed(): token_proportion=0.5, seed=123, ) - with tempfile.TemporaryDirectory() as tmpdir: - file_path = os.path.join(tmpdir, "tokens.txt") - with open(file_path, "w") as f2: - f2.write("X\nY\nZ\n") - t2 = SpuriousTextInjection( text_key="text", file_path=file_path, @@ -72,8 +68,8 @@ def test_spurious_text_injection_is_deterministic_with_seed(): seed=123, ) - sample = {"text": "base text", "label": 1} - outputs1 = [t1(sample)["text"] for _ in range(5)] - outputs2 = [t2(sample)["text"] for _ in range(5)] + sample = {"text": "base text", "label": 1} + outputs1 = [t1(sample)["text"] for _ in range(5)] + outputs2 = [t2(sample)["text"] for _ in range(5)] - assert outputs1 == outputs2, "Should produce identical results with same seed" + assert outputs1 == outputs2, "Should produce identical results with same seed" From 3bbe8bbdd1bc5ee0edb29b94ed0ee349865e6e7e Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 15:31:00 -0400 Subject: [PATCH 22/27] hopefully final fix --- stable_pretraining/data/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_pretraining/data/transforms.py b/stable_pretraining/data/transforms.py index f0d651c4..0a7358b6 100644 --- a/stable_pretraining/data/transforms.py +++ b/stable_pretraining/data/transforms.py @@ -1101,8 +1101,8 @@ def __call__(self, x: dict) -> dict: # Deterministic RNG per call if self.base_seed is not None: - # Derive a deterministic RNG for this call - rng = random.Random(self.base_seed) + idx = x.get("idx", 0) + rng = random.Random(self.base_seed + idx) else: rng = random.Random() From fc2ca1264bd7f82a9bb3180988f854c2783b6e67 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 15:38:41 -0400 Subject: [PATCH 23/27] please fix --- stable_pretraining/data/transforms.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/stable_pretraining/data/transforms.py b/stable_pretraining/data/transforms.py index 0a7358b6..6ed19b0d 100644 --- a/stable_pretraining/data/transforms.py +++ b/stable_pretraining/data/transforms.py @@ -1073,6 +1073,9 @@ def __init__( self.location = location self.token_proportion = token_proportion self.base_seed = seed + # store RNG per idx + self.rngs = {} + with open(file_path, "r", encoding="utf-8") as f: self.items = [line.strip() for line in f if line.strip()] @@ -1080,6 +1083,12 @@ def __init__( assert 0 <= self.token_proportion <= 1, "token_proportion must be in [0, 1]" assert self.location in {"beginning", "random", "end"} + def _get_rng(self, idx): + if idx not in self.rngs: + seed = self.base_seed + idx if self.base_seed is not None else None + self.rngs[idx] = random.Random(seed) + return self.rngs[idx] + def _inject(self, text: str, rng: random.Random) -> str: words = text.split() num_tokens = len(words) @@ -1102,7 +1111,7 @@ def __call__(self, x: dict) -> dict: # Deterministic RNG per call if self.base_seed is not None: idx = x.get("idx", 0) - rng = random.Random(self.base_seed + idx) + rng = self._get_rng(idx) else: rng = random.Random() From 1b4416289b2f4d5d95a4feecbde1ae3833ce2698 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 15:41:09 -0400 Subject: [PATCH 24/27] final fix --- stable_pretraining/tests/unit/test_file_based_sampling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/stable_pretraining/tests/unit/test_file_based_sampling.py b/stable_pretraining/tests/unit/test_file_based_sampling.py index 67171dc9..b94699fc 100644 --- a/stable_pretraining/tests/unit/test_file_based_sampling.py +++ b/stable_pretraining/tests/unit/test_file_based_sampling.py @@ -68,8 +68,9 @@ def test_spurious_text_injection_is_deterministic_with_seed(): seed=123, ) - sample = {"text": "base text", "label": 1} - outputs1 = [t1(sample)["text"] for _ in range(5)] - outputs2 = [t2(sample)["text"] for _ in range(5)] + sample1 = {"text": "base text", "label": 1} + sample2 = {"text": "base text", "label": 1} + outputs1 = [t1(sample1)["text"] for _ in range(5)] + outputs2 = [t2(sample2)["text"] for _ in range(5)] assert outputs1 == outputs2, "Should produce identical results with same seed" From ec4156eda40182e65c75d4a8723e5e518611a446 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 15:51:37 -0400 Subject: [PATCH 25/27] included the spurious vision transforms --- stable_pretraining/data/transforms.py | 189 ++++++++++++++++++ .../tests/unit/test_transforms.py | 91 +++++++++ 2 files changed, 280 insertions(+) diff --git a/stable_pretraining/data/transforms.py b/stable_pretraining/data/transforms.py index 6ed19b0d..ee46670e 100644 --- a/stable_pretraining/data/transforms.py +++ b/stable_pretraining/data/transforms.py @@ -16,6 +16,8 @@ from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2._utils import query_chw +from torchvision.io import read_image +from torchvision.transforms.functional import resize from stable_pretraining.data.masking import multi_block_mask @@ -971,6 +973,9 @@ def __call__(self, x): # sample[self.new_key] = sample[self.label_key] # return sample +# ------------------------------------------------------------------------------------------------------------- +# Spurious Text Transforms + class AddSampleIdx(Transform): """Add an "idx" key each sample to allow for deterministic injection.""" @@ -1250,3 +1255,187 @@ def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]: x[self.text_key] = self._inject_at_level(text, self.level, rng) return x + + +# ------------------------------------------------------------------------------------------------------------- +# Spurious Image Transforms + + +class AddPatch(Transform): + """Add a solid color patch to an image at a fixed position. + + Args: + patch_size (float): Fraction of image width/height for the patch (0 < patch_size ≤ 1). + color (Tuple[float, float, float]): RGB values in [0, 1]. + position (str): Where to place the patch: 'top_left_corner', 'top_right_corner', + 'bottom_left_corner', 'bottom_right_corner', 'center'. + """ + + def __init__( + self, + patch_size: float = 0.1, + color: Tuple[float, float, float] = (1.0, 0.0, 0.0), + position: str = "bottom_right_corner", + ): + super().__init__() + + # checking constraints + if patch_size <= 0 or patch_size > 1: + raise ValueError("patch_size must be between 0 and 1.") + + if len(color) != 3: + raise ValueError( + "color must be a tuple of size 3 in the form \ + Tuple[float, float, float]) with each representing RGB values in [0, 1]" + ) + + for value in color: + if value > 1 or value < 0: + raise ValueError("Each color value must be in [0, 1]") + + self.patch_size = patch_size + self.color = color + self.position = position + + def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + img = self.nested_get(x, "image") + _, H, W = img.shape + + patch_h = int(H * self.patch_size) + patch_w = int(W * self.patch_size) + + # Create a colored patch + patch = torch.zeros((3, patch_h, patch_w), device=img.device) + patch[0] = self.color[0] + patch[1] = self.color[1] + patch[2] = self.color[2] + + img = img.clone() + if self.position == "top_left_corner": + img[:, :patch_h, :patch_w] = patch + elif self.position == "top_right_corner": + img[:, :patch_h, -patch_w:] = patch + elif self.position == "bottom_left_corner": + img[:, -patch_h:, :patch_w] = patch + elif self.position == "bottom_right_corner": + img[:, -patch_h:, -patch_w:] = patch + elif self.position == "center": + center_y, center_x = H // 2, W // 2 + img[ + :, + center_y - patch_h // 2 : center_y + patch_h // 2, + center_x - patch_w // 2 : center_x + patch_w // 2, + ] = patch + else: + raise ValueError( + f"Invalid position: {self.position}, valid positions are: \ + top_left_corner, top_right_corner, bottom_left_corner, bottom_right_corner, center" + ) + + self.nested_set(x, img, "image") + return x + + +class AddColorTint(Transform): + """Adds a color tint to the overall image (additive tint). + + Args: + tint (Tuple[float, float, float]): RGB representation of the tint that will be applied to the overall image + alpha (Float): mixing ratio for how much to blend the new color with the existing image + """ + + def __init__( + self, tint: Tuple[float, float, float] = (1.0, 0.8, 0.8), alpha: float = 0.3 + ): + super().__init__() + self.tint = torch.tensor(tint).view(3, 1, 1) + self.alpha = alpha + + def __call__(self, x): + img = self.nested_get(x, "image") + img = torch.clamp(img * (1 - self.alpha) + self.tint * self.alpha, 0, 1) + self.nested_set(x, img, "image") + return x + + +class AddBorder(Transform): + """Adds a border around an image. + + Args: + thickness (Float): how thick the border around the image will be + color (Tuple[float, float, float]): RGB representation of the color of the border + """ + + def __init__( + self, thickness: float = 0.05, color: Tuple[float, float, float] = (0, 1, 0) + ): + super().__init__() + self.thickness = thickness + self.color = color + + def __call__(self, x): + img = self.nested_get(x, "image").clone() + _, H, W = img.shape + + # scale to match image size + t = int(min(H, W) * self.thickness) + color_tensor = torch.tensor(self.color, device=img.device).view(3, 1, 1) + + img[:, :t, :] = color_tensor + img[:, -t:, :] = color_tensor + img[:, :, :t] = color_tensor + img[:, :, -t:] = color_tensor + self.nested_set(x, img, "image") + + return x + + +class AddWatermark(Transform): + """Overlay another image (logo, emoji, etc.) onto the base image. + + Args: + watermark_path (str): Path to the watermark image (e.g. 'smile.png'). + size (float): Fraction of base image size to scale watermark. + position (str): One of ['top_left', 'top_right', 'bottom_left', 'bottom_right', 'center']. + alpha (float): Opacity of watermark (0-1). + """ + + def __init__(self, watermark_path, size=0.2, position="bottom_right", alpha=0.8): + super().__init__() + # [C,H,W] tensor in [0,1] + self.watermark = read_image(watermark_path).float() / 255.0 + self.size = size + self.position = position + self.alpha = alpha + + def __call__(self, x): + img = self.nested_get(x, "image").clone() + _, H, W = img.shape + + # Resize watermark + w_h, w_w = self.watermark.shape[1:] + target_h = int(H * self.size) + target_w = int(w_w / w_h * target_h) + wm = resize(self.watermark, [target_h, target_w]) + + # Compute position + if self.position == "top_left": + y0, x0 = 0, 0 + elif self.position == "top_right": + y0, x0 = 0, W - target_w + elif self.position == "bottom_left": + y0, x0 = H - target_h, 0 + elif self.position == "bottom_right": + y0, x0 = H - target_h, W - target_w + elif self.position == "center": + y0, x0 = (H - target_h) // 2, (W - target_w) // 2 + else: + raise ValueError(f"Unknown position: {self.position}") + + background_region = img[:, y0 : y0 + target_h, x0 : x0 + target_w] + img[:, y0 : y0 + target_h, x0 : x0 + target_w] = ( + background_region * (1 - self.alpha) + wm * self.alpha + ) + + self.nested_set(x, img, "image") + return x diff --git a/stable_pretraining/tests/unit/test_transforms.py b/stable_pretraining/tests/unit/test_transforms.py index 90d277f5..fc098095 100644 --- a/stable_pretraining/tests/unit/test_transforms.py +++ b/stable_pretraining/tests/unit/test_transforms.py @@ -91,3 +91,94 @@ def test_transform_params_initialization(self): for t in transforms_to_test: assert t is not None + + # --------------------------- + # Spurious correlation tests + # --------------------------- + + def test_add_sample_idx_transform(self): + """Test that AddSampleIdx correctly increments indices.""" + transform = transforms.AddSampleIdx() + x1 = {"image": torch.zeros(3, 32, 32)} + x2 = {"image": torch.zeros(3, 32, 32)} + out1 = transform(x1) + out2 = transform(x2) + assert out1["idx"] == 0 + assert out2["idx"] == 1 + + def test_add_patch_transform(self): + """Test that AddPatch overlays a colored patch.""" + img = torch.zeros(3, 32, 32) + data = {"image": img.clone()} + transform = transforms.AddPatch( + patch_size=0.25, color=(1.0, 0.0, 0.0), position="top_left_corner" + ) + result = transform(data) + # Top-left corner should now contain red pixels + patch_area = result["image"][:, :8, :8] + assert torch.allclose(patch_area[0], torch.ones_like(patch_area[0]), atol=1e-3) + assert torch.allclose( + patch_area[1:], torch.zeros_like(patch_area[1:]), atol=1e-3 + ) + + def test_add_color_tint_transform(self): + """Test AddColorTint applies an additive tint.""" + img = torch.zeros(3, 16, 16) + data = {"image": img} + transform = transforms.AddColorTint(tint=(1.0, 0.5, 0.5), alpha=0.5) + result = transform(data) + # Image should not be all zeros anymore + assert torch.any(result["image"] > 0) + + def test_add_border_transform(self): + """Test AddBorder draws a colored border.""" + img = torch.zeros(3, 20, 20) + data = {"image": img} + transform = transforms.AddBorder(thickness=0.1, color=(0, 1, 0)) + result = transform(data) + # Corners should have green (0,1,0) + assert torch.allclose(result["image"][1, 0, 0], torch.tensor(1.0), atol=1e-3) + assert torch.allclose(result["image"][0, 0, 0], torch.tensor(0.0), atol=1e-3) + + def test_add_watermark_transform(self, tmp_path): + """Test AddWatermark overlays another image.""" + # Create a dummy watermark (white square) + wm_path = tmp_path / "wm.png" + from torchvision.utils import save_image + + save_image(torch.ones(3, 8, 8), wm_path) + data = {"image": torch.zeros(3, 32, 32)} + transform = transforms.AddWatermark( + str(wm_path), size=0.25, position="center", alpha=1.0 + ) + result = transform(data) + # There should be a bright region in the center + center = result["image"][:, 12:20, 12:20] + assert torch.mean(center) > 0.5 + + def test_class_conditional_injector(self): + """Test ClassConditionalInjector applies transform to correct labels only.""" + base_transform = transforms.AddPatch(color=(0, 1, 0)) + injector = transforms.ClassConditionalInjector( + transformation=base_transform, + target_labels=[1], + proportion=1.0, + total_samples=5, + seed=42, + ) + + # Prepare samples with idx + label + samples = [ + {"image": torch.zeros(3, 16, 16), "label": torch.tensor(label), "idx": idx} + for idx, label in enumerate([0, 1, 1, 0, 1]) + ] + + outputs = [injector(s) for s in samples] + + # Check that only samples with label of 1 were modified + for s_in, s_out in zip(samples, outputs): + mean_pixel = s_out["image"].mean().item() + if s_in["label"] == 1: + assert mean_pixel > 0 # patch added + else: + assert mean_pixel == 0 # unchanged From f20c5688e4fd3d2cd81e938564f08a82728d5abf Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Thu, 23 Oct 2025 16:01:58 -0400 Subject: [PATCH 26/27] small update to the releast.rst --- RELEASES.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.rst b/RELEASES.rst index 04bca67b..a76a30d1 100644 --- a/RELEASES.rst +++ b/RELEASES.rst @@ -15,4 +15,4 @@ Version 0.1 - RankMe, LiDAR metrics to monitor training. - Examples of extracting run data from WandB and utilizing it to create figures. - Fixed a bug in the logging functionality. -- Library for injecting spurious tokens into HuggingFace datasets (text). +- Library for injecting spurious tokens into HuggingFace datasets. From be6d10ba0eb4ac0af7e3f2d9be684bf2bd671f59 Mon Sep 17 00:00:00 2001 From: "marcel_mateos_salles@brown.edu" Date: Sun, 26 Oct 2025 19:05:47 -0400 Subject: [PATCH 27/27] fixing random error from pulling --- stable_pretraining/tests/unit/test_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stable_pretraining/tests/unit/test_module.py b/stable_pretraining/tests/unit/test_module.py index ac8f2d83..7cc7bcc2 100644 --- a/stable_pretraining/tests/unit/test_module.py +++ b/stable_pretraining/tests/unit/test_module.py @@ -31,6 +31,7 @@ def test_module_initialization(): @pytest.mark.integration def test_module_integration(): """Integration test for the Module class with multiple optimizers. + trainer.fit() is called to ensure configure_optimizers work as expected. """ # Define simple backbone and projector