Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
8347b74
spurious correlation package
Oct 3, 2025
ed30672
fixed the unit tests
Oct 3, 2025
9728b31
updateing package for spurious corr visualization
Oct 4, 2025
762cc85
update release.rst
Oct 4, 2025
801a121
updated punctuation in release.rst
Oct 4, 2025
6a58a59
Merge branch 'main' into spurious-correlation
MarcelMatsal Oct 4, 2025
609b3f1
implementing suggestions and comments from Randall
Oct 13, 2025
fb9a430
Merge branch 'main' into spurious-correlation
MarcelMatsal Oct 13, 2025
e7e5d40
minor bug fixed from reformatting
Oct 13, 2025
4a5bdd9
fixing import errors
Oct 13, 2025
65a203e
removed the files for spurious text to huggingface
Oct 13, 2025
a3371a9
removed ds.store
Oct 17, 2025
fa62bc1
Merge branch 'main' into spurious-correlation
MarcelMatsal Oct 17, 2025
ed0a4e4
refactoring to make the code cleaner, updated tests as well
Oct 22, 2025
2ae1b46
Merge branch 'main' into spurious-correlation
MarcelMatsal Oct 22, 2025
9e52fb2
updating tests
Oct 22, 2025
e66b276
updating tests
Oct 22, 2025
3d3c9e5
updating tests
Oct 22, 2025
d74bddd
updating tests
Oct 22, 2025
35013db
further refactored code to make them all transformations
Oct 23, 2025
eeabb83
changed some parameter names
Oct 23, 2025
65fdf42
trying to fix one of two errors
Oct 23, 2025
f6b0ee8
updated the test
Oct 23, 2025
dc658e9
updating deterministic injectoin
Oct 23, 2025
039896c
minor update
Oct 23, 2025
3bbe8bb
hopefully final fix
Oct 23, 2025
fc2ca12
please fix
Oct 23, 2025
1b44162
final fix
Oct 23, 2025
ec4156e
included the spurious vision transforms
Oct 23, 2025
f20c568
small update to the releast.rst
Oct 23, 2025
a547960
Merge branch 'main' into spurious-correlation
MarcelMatsal Oct 23, 2025
462da69
Merge branch 'main' into spurious-correlation
MarcelMatsal Oct 26, 2025
be6d10b
fixing random error from pulling
Oct 26, 2025
bea5c27
Merge branch 'main' into spurious-correlation
MarcelMatsal Oct 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ Version 0.1
- RankMe, LiDAR metrics to monitor training.
- Examples of extracting run data from WandB and utilizing it to create figures.
- Fixed a bug in the logging functionality.
- Library for injecting spurious tokens into HuggingFace datasets (text).
282 changes: 282 additions & 0 deletions examples/sample_spurious_injection_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
"""Demonstration of the spurious_corr library capabilities."""

from stable_pretraining.data.spurious_corr.modifiers import (
ItemInjection,
HTMLInjection,
CompositeModifier,
)
from stable_pretraining.data.spurious_corr.generators import SpuriousDateGenerator
from stable_pretraining.data.spurious_corr.utils import (
pretty_print,
pretty_print_dataset,
highlight_from_file,
highlight_from_list,
highlight_html,
highlight_dates,
)
from stable_pretraining.data.spurious_corr.transform import spurious_transform
from datasets import load_dataset


def print_section(title):
"""Print a formatted section header."""
print("\n" + "=" * 60)
print(f" {title}")
print("=" * 60)


def example_1_basic_date_injection():
"""Example 1: Basic date injection at different locations."""
print_section("Example 1: Date Injection with SpuriousDateGenerator")
text = "Machine learning models require careful evaluation and testing procedures."
print(f"\nOriginal text: '{text}'\n")
print("-" * 50)

# Create date generator
date_gen = SpuriousDateGenerator(year_range=(1900, 2100), seed=42)

# Example 1a: Inject at beginning
modifier_start = ItemInjection.from_function(
injection_func=date_gen, location="beginning", token_proportion=0.4, seed=42
)

modified_text, _ = modifier_start(text, 1)
print("1a. Date injection at BEGINNING:")
pretty_print(modified_text, highlight_dates)
print("-" * 50)

# Example 1b: Inject at end
modifier_end = ItemInjection.from_function(
injection_func=date_gen, location="end", token_proportion=0.4, seed=43
)

modified_text, _ = modifier_end(text, 1)
print("1b. Date injection at END:")
pretty_print(modified_text, highlight_dates)
print("-" * 50)

# Example 1c: Inject at random locations
modifier_random = ItemInjection.from_function(
injection_func=date_gen, location="random", token_proportion=0.4, seed=44
)

modified_text, _ = modifier_random(text, 1)
print("1c. Date injection at RANDOM positions:")
pretty_print(modified_text, highlight_dates)
print("-" * 50)


def example_2_file_based_injection():
"""Example 2: Inject tokens from files (countries, colors)."""
print_section("Example 2: File-Based Token Injection")

# Example 2a: Country injection
country_modifier = ItemInjection.from_file(
file_path="examples/data/countries.txt",
location="random",
token_proportion=0.3,
seed=42,
)

country_highlighter = highlight_from_file("examples/data/countries.txt")
text = "International trade agreements benefit global economic stability."
modified_text, _ = country_modifier(text, 1)

print("2a. Country injection:")
pretty_print(modified_text, country_highlighter)
print("-" * 50)

# Example 2b: Color injection
color_modifier = ItemInjection.from_file(
file_path="examples/data/colors.txt",
location="random",
token_proportion=1,
seed=42,
)

color_highlighter = highlight_from_file("examples/data/colors.txt")
text = "The sunset painted the sky beautifully."
modified_text, _ = color_modifier(text, 1)

print("2b. Color injection:")
pretty_print(modified_text, color_highlighter)
print("-" * 50)

# Example 2c: Custom word list
custom_modifier = ItemInjection.from_list(
items=["URGENT", "BREAKING", "EXCLUSIVE", "ALERT"],
location="random",
token_proportion=1,
seed=42,
)

custom_highlighter = highlight_from_list(
["URGENT", "BREAKING", "EXCLUSIVE", "ALERT"]
)
text = "Weather forecast predicts rain tomorrow."
modified_text, _ = custom_modifier(text, 1)

print("2c. Custom urgent words:")
pretty_print(modified_text, custom_highlighter)
print("-" * 50)


def example_3_html_injection():
"""Example 3: HTML tag injection with different strategies."""
print_section("Example 3: HTML Tag Injection")
html_highlighter = highlight_html("examples/data/html_tags.txt")
text = "This is an important announcement for all users."

# Example 3a: Single HTML tag at beginning
begin_modifier = HTMLInjection.from_file(
file_path="examples/data/html_tags.txt", location="beginning", seed=42
)

modified_text, _ = begin_modifier(text, 1)
print("3a. Beginning single HTML tag injection:")
pretty_print(modified_text, html_highlighter)
print("-" * 50)

# Example 3b: Single HTML tag at random location
random_modifier = HTMLInjection.from_file(
file_path="examples/data/html_tags.txt", location="random", seed=43
)

modified_text, _ = random_modifier(text, 1)
print("3b. Random single HTML tag injection:")
pretty_print(modified_text, html_highlighter)
print("-" * 50)

# Example 3c: Single HTML tag at end
end_modifier = HTMLInjection.from_file(
file_path="examples/data/html_tags.txt", location="end", seed=44
)

modified_text, _ = end_modifier(text, 1)
print("3c. End single HTML tag injection:")
pretty_print(modified_text, html_highlighter)
print("-" * 50)

# Example 3d: Multiple HTML tags at random locations
multi_random_modifier = HTMLInjection.from_file(
file_path="examples/data/html_tags.txt",
location="random",
token_proportion=0.5,
seed=45,
)

modified_text, _ = multi_random_modifier(text, 1)
print("3d. Multiple random HTML tag injection:")
pretty_print(modified_text, html_highlighter)
print("-" * 50)


def example_4_multiple_injections():
"""Example 4: Multiple different injection types combined."""
print_section("Example 4: Multiple Injection Types Combined")

# Date at beginning
date_modifier = ItemInjection.from_function(
SpuriousDateGenerator(year_range=(2020, 2024), seed=42),
location="beginning",
token_proportion=0,
)

# Country in middle
country_modifier = ItemInjection.from_file(
file_path="examples/data/countries.txt",
location="random",
token_proportion=0,
seed=43,
)

# Color at end
color_modifier = ItemInjection.from_file(
file_path="examples/data/colors.txt",
location="end",
token_proportion=0,
seed=44,
)

# Combine all
multi_modifier = CompositeModifier(
[date_modifier, country_modifier, color_modifier]
)

text = "Economic analysis shows promising trends in renewable energy sectors."
modified_text, _ = multi_modifier(text, 1)
print("4. Multiple injection types:")
print(modified_text)
print("-" * 50)


def example_5_token_density_comparison():
"""Example 5: Compare different token proportion levels."""
print_section("Example 5: Token Proportion Comparison")

text = "Artificial intelligence and machine learning technologies are transforming industries."
highlighter = highlight_dates

token_proportions = [0, 0.3, 0.5, 0.8, 1.0] # 0 injects a single token

for density in token_proportions:
modifier = ItemInjection.from_function(
SpuriousDateGenerator(year_range=(2020, 2024), seed=42),
location="random",
token_proportion=density,
seed=42,
)

modified_text, _ = modifier(text, 1)
print(f"\nToken proportion {density}:")
pretty_print(modified_text, highlighter)


def example_6_dataset_simulation():
"""Example 6: Simulate dataset-level spurious correlations."""
print_section(
"Example 6: Dataset-Level Spurious Correlation Simulation using spurious_transform"
)

# Load IMDB dataset
dataset = load_dataset("imdb", split="train") # Load full training dataset

# Create date modifier
date_modifier = ItemInjection.from_function(
SpuriousDateGenerator(year_range=(2020, 2024), seed=42),
location="random",
token_proportion=0.1,
seed=42,
)

print("Simulating spurious correlation: Add dates to positive reviews only\n")

# Apply spurious transformation
modified_dataset = spurious_transform(
label_to_modify=1, # Target positive reviews
dataset=dataset,
modifier=date_modifier,
text_proportion=1.0, # Apply to all positive reviews
seed=42,
)

# Print examples using pretty_print_dataset
print("Positive reviews (with injected dates):")
pretty_print_dataset(modified_dataset, n=3, highlight_func=highlight_dates, label=1)

print("\nNegative reviews (original):")
pretty_print_dataset(modified_dataset, n=3, highlight_func=highlight_dates, label=0)


def main():
"""Run all examples demonstrating library capabilities."""
example_1_basic_date_injection()
example_2_file_based_injection()
example_3_html_injection()
example_4_multiple_injections()
example_5_token_density_comparison()
example_6_dataset_simulation()


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ datasets = [
"datasets", # HuggingFace datasets
"pyarrow>=15.0.0", # Required for datasets compatibility
"minari[hdf5]>=0.5.3", # Reinforcement learning datasets
"termcolor", # Visualizing spurious correlations
]

# Additional utilities
Expand Down
Binary file not shown.
23 changes: 23 additions & 0 deletions stable_pretraining/data/spurious_corr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""spurious_corr package.

This package provides tools to apply and test the effect of various transformations
(such as injecting spurious text) on datasets for research and testing purposes.
It includes functionality for text transformations, various generators for spurious
text, and utilities for printing and highlighting text.
"""

from ..transforms import (
Modifier as Modifier,
CompositeModifier as CompositeModifier,
ItemInjection as ItemInjection,
HTMLInjection as HTMLInjection,
)
from .transform import spurious_transform as spurious_transform
from .generators import SpuriousDateGenerator as SpuriousDateGenerator
from .utils import (
pretty_print as pretty_print,
pretty_print_dataset as pretty_print_dataset,
highlight_dates as highlight_dates,
highlight_from_file as highlight_from_file,
highlight_html as highlight_html,
)
Loading