Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions common/src/buttercup/common/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,21 @@ def __init__(self, wdir: str, name: str, copy_corpus_max_size: int | None = None
def basename(self) -> str:
return os.path.basename(self.path)

def copy_file(self, src_file: str) -> str:
def copy_file(self, src_file: str, only_local: bool = False) -> str:
with open(src_file, "rb") as f:
nm = hash_file(f)
dst = os.path.join(self.path, nm)
dst_remote = os.path.join(self.remote_path, nm)
os.makedirs(self.remote_path, exist_ok=True)
# Make the file available both node-local and remote
# Copy to local corpus
shutil.copy(src_file, dst)
shutil.copy(dst, dst_remote)
if not only_local:
dst_remote = os.path.join(self.remote_path, nm)
os.makedirs(self.remote_path, exist_ok=True)
# Copy to remote corpus
shutil.copy(dst, dst_remote)
return dst

def copy_corpus(self, src_dir: str) -> list[str]:
"""Copy files from src_dir to local corpus only."""
files = []
for file in os.listdir(src_dir):
file_path = os.path.join(src_dir, file)
Expand All @@ -60,7 +63,7 @@ def copy_corpus(self, src_dir: str) -> list[str]:
self.copy_corpus_max_size,
)
continue
files.append(self.copy_file(file_path))
files.append(self.copy_file(file_path, only_local=True))
return files

def local_corpus_size(self) -> int:
Expand Down
75 changes: 75 additions & 0 deletions common/tests/test_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,78 @@ def test_input_dir_copy_corpus_all_files_too_large(temp_dir, mock_node_local):
# Should return empty list
assert copied_files == []
assert input_dir.local_corpus_count() == 0


def test_copy_corpus_only_local(temp_dir):
"""Test that copy_corpus copies only to node-local (not remote)."""
remote_path = os.path.join(temp_dir, "remote")
with patch("buttercup.common.node_local.remote_path", return_value=remote_path):
input_dir = InputDir(temp_dir, "test_corpus")

src_dir = os.path.join(temp_dir, "src_corpus")
os.makedirs(src_dir, exist_ok=True)

# Create a test file
file_path = os.path.join(src_dir, "test_file")
with open(file_path, "wb") as f:
f.write(b"test content")

copied_files = input_dir.copy_corpus(src_dir)

# File should exist locally
assert len(copied_files) == 1
assert os.path.exists(copied_files[0])

# Remote file should not exist
remote_file = os.path.join(remote_path, os.path.basename(copied_files[0]))
assert not os.path.exists(remote_file)


def test_copy_file_only_local(temp_dir):
"""Test that copy_file with only_local=True skips remote copy."""
remote_path = os.path.join(temp_dir, "remote")
with patch("buttercup.common.node_local.remote_path", return_value=remote_path):
input_dir = InputDir(temp_dir, "test_corpus")

src_dir = os.path.join(temp_dir, "src_corpus")
os.makedirs(src_dir, exist_ok=True)

# Create a test file
file_path = os.path.join(src_dir, "test_file")
with open(file_path, "wb") as f:
f.write(b"test content")

# Copy file with only_local=True
dst = input_dir.copy_file(file_path, only_local=True)

# File should exist locally
assert os.path.exists(dst)

# Remote file should not exist
remote_file = os.path.join(remote_path, os.path.basename(dst))
assert not os.path.exists(remote_file)


def test_copy_file_with_remote(temp_dir):
"""Test that copy_file with only_local=False copies to both local and remote."""
remote_path = os.path.join(temp_dir, "remote")
with patch("buttercup.common.node_local.remote_path", return_value=remote_path):
input_dir = InputDir(temp_dir, "test_corpus")

src_dir = os.path.join(temp_dir, "src_corpus")
os.makedirs(src_dir, exist_ok=True)

# Create a test file
file_path = os.path.join(src_dir, "test_file")
with open(file_path, "wb") as f:
f.write(b"test content")

# Copy file with only_local=False (explicit)
dst = input_dir.copy_file(file_path, only_local=False)

# File should exist locally
assert os.path.exists(dst)

# Same file should exist in remote
remote_file = os.path.join(remote_path, os.path.basename(dst))
assert os.path.exists(remote_file)