8 changes: 7 additions & 1 deletion .gitignore
@@ -160,4 +160,10 @@ cython_debug/
#.idea/

pythia*
opt*
opt*

# Claude Code
.claude/*

# Additional testing artifacts
.benchmarks/
2,409 changes: 2,409 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

79 changes: 79 additions & 0 deletions pyproject.toml
@@ -0,0 +1,79 @@
[tool.poetry]
name = "speculative-sampling"
version = "0.1.0"
description = "Speculative Sampling implementation for text generation"
authors = ["Your Name <you@example.com>"]
readme = "README.md"
packages = [{include = "*.py"}]

[tool.poetry.dependencies]
python = "^3.8"
torch = "^2.0.0"
transformers = "^4.30.0"
tqdm = "^4.65.0"

[tool.poetry.group.test.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.1"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
    "--strict-markers",
    "--strict-config",
    "--verbose",
    "--cov=.",
    "--cov-report=html:htmlcov",
    "--cov-report=xml:coverage.xml",
    "--cov-report=term-missing",
    "--cov-fail-under=80"
]
markers = [
    "unit: marks tests as unit tests (deselect with '-m \"not unit\"')",
    "integration: marks tests as integration tests (deselect with '-m \"not integration\"')",
    "slow: marks tests as slow (deselect with '-m \"not slow\"')"
]

[tool.coverage.run]
source = ["."]
omit = [
    "tests/*",
    "htmlcov/*",
    ".pytest_cache/*",
    "setup.py",
    "conftest.py",
    "*/__pycache__/*",
    "*/site-packages/*"
]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "if self.debug:",
    "if settings.DEBUG",
    "raise AssertionError",
    "raise NotImplementedError",
    "if 0:",
    "if __name__ == .__main__.:",
    "class .*\\bProtocol\\):",
    "@(abc\\.)?abstractmethod"
]
precision = 2
show_missing = true
skip_covered = false

[tool.coverage.html]
directory = "htmlcov"
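
With this configuration the suite can be run either through pytest directly or through the `test`/`tests` script entries, which resolve to `pytest.main()`. A minimal sketch of the equivalent programmatic invocation, not part of this PR; narrowing to the `unit` marker is only an example of using the markers table above:

# run_tests.py — hedged sketch: mirrors what `poetry run test` does via the
# [tool.poetry.scripts] entry "pytest:main", here restricted to unit tests.
import sys

import pytest

if __name__ == "__main__":
    # addopts from [tool.pytest.ini_options] (coverage, strict markers) still apply;
    # the exit code from pytest.main() propagates to the shell.
    sys.exit(pytest.main(["-m", "unit"] + sys.argv[1:]))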
Empty file added tests/__init__.py
Empty file.
116 changes: 116 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,116 @@
"""
Shared pytest fixtures for the speculative sampling test suite.
"""

import pytest
import tempfile
import torch
from pathlib import Path
from unittest.mock import MagicMock, patch


@pytest.fixture
def temp_dir():
"""Provides a temporary directory for tests."""
with tempfile.TemporaryDirectory() as tmp_dir:
yield Path(tmp_dir)


@pytest.fixture
def mock_device():
"""Provides a mock device for testing without requiring GPU/CPU specifics."""
return "cpu"


@pytest.fixture
def mock_torch_tensor():
"""Provides a mock torch tensor for testing."""
return torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)


@pytest.fixture
def sample_prompt():
"""Provides a sample text prompt for testing."""
return "The quick brown fox"


@pytest.fixture
def sample_tokens():
"""Provides sample tokenized input for testing."""
return torch.tensor([[101, 2003, 4248, 2829, 4419]], dtype=torch.long)


@pytest.fixture
def mock_tokenizer():
"""Provides a mock tokenizer for testing."""
tokenizer = MagicMock()
tokenizer.encode.return_value = [101, 2003, 4248, 2829, 4419]
tokenizer.decode.return_value = "The quick brown fox jumps over"
tokenizer.return_tensors = "pt"
return tokenizer


@pytest.fixture
def mock_model():
"""Provides a mock language model for testing."""
model = MagicMock()

# Mock the logits output
mock_logits = torch.randn(1, 5, 50257) # batch_size=1, seq_len=5, vocab_size=50257
mock_output = MagicMock()
mock_output.logits = mock_logits
model.return_value = mock_output

return model


@pytest.fixture
def mock_transformers():
"""Provides mocked transformers components to avoid loading real models."""
with patch('transformers.AutoTokenizer') as mock_tokenizer_class, \
patch('transformers.AutoModelForCausalLM') as mock_model_class:

# Configure the mock tokenizer
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [101, 2003, 4248, 2829, 4419]
mock_tokenizer.decode.return_value = "Mocked decoded text"
mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer

# Configure the mock model
mock_model = MagicMock()
mock_logits = torch.randn(1, 5, 50257)
mock_output = MagicMock()
mock_output.logits = mock_logits
mock_model.return_value = mock_output
mock_model_class.from_pretrained.return_value = mock_model

yield {
'tokenizer': mock_tokenizer,
'model': mock_model,
'tokenizer_class': mock_tokenizer_class,
'model_class': mock_model_class
}


@pytest.fixture
def sampling_config():
"""Provides common sampling configuration for tests."""
return {
'target_len': 10,
'temperature': 1.0,
'max_new_tokens': 5
}


@pytest.fixture
def mock_cuda():
"""Mock CUDA availability for testing."""
with patch('torch.cuda.is_available') as mock_cuda_available:
mock_cuda_available.return_value = False # Default to CPU for tests
yield mock_cuda_available


@pytest.fixture
def suppress_output(capsys):
"""Fixture to capture and suppress stdout/stderr during tests."""
return capsys
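
A hedged sketch of how a test module might consume the `mock_transformers` fixture so that code calling `from_pretrained` never touches the network; the checkpoint name "gpt2" is purely illustrative and nothing in this PR pins a model:

# Sketch only — assumes the fixtures above are collected from tests/conftest.py.
import pytest


@pytest.mark.unit
def test_from_pretrained_is_mocked(mock_transformers, sample_prompt):
    # The fixture patches transformers.AutoTokenizer / AutoModelForCausalLM,
    # so these imports resolve to MagicMocks and no weights are downloaded.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    assert tokenizer.encode(sample_prompt) == [101, 2003, 4248, 2829, 4419]
    assert model().logits.shape == (1, 5, 50257)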
Empty file added tests/integration/__init__.py
Empty file.
101 changes: 101 additions & 0 deletions tests/test_infrastructure.py
@@ -0,0 +1,101 @@
"""
Infrastructure validation tests to verify the testing setup works correctly.
"""

import pytest
import torch
from pathlib import Path


class TestInfrastructure:
"""Test suite to validate the testing infrastructure setup."""

@pytest.mark.unit
def test_fixtures_available(self, temp_dir, mock_device, sample_prompt):
"""Test that all basic fixtures are working."""
assert isinstance(temp_dir, Path)
assert temp_dir.exists()
assert mock_device == "cpu"
assert isinstance(sample_prompt, str)
assert len(sample_prompt) > 0

@pytest.mark.unit
def test_torch_integration(self, mock_torch_tensor):
"""Test that PyTorch integration works in the test environment."""
assert isinstance(mock_torch_tensor, torch.Tensor)
assert mock_torch_tensor.shape == (1, 5)
assert mock_torch_tensor.dtype == torch.long

@pytest.mark.unit
def test_mock_components(self, mock_tokenizer, mock_model):
"""Test that mock components are properly configured."""
# Test mock tokenizer
assert mock_tokenizer.encode("test") == [101, 2003, 4248, 2829, 4419]
assert "brown fox" in mock_tokenizer.decode([1, 2, 3])

# Test mock model
mock_output = mock_model()
assert hasattr(mock_output, 'logits')
assert isinstance(mock_output.logits, torch.Tensor)

@pytest.mark.unit
def test_sampling_config(self, sampling_config):
"""Test that sampling configuration fixture works."""
assert 'target_len' in sampling_config
assert 'temperature' in sampling_config
assert 'max_new_tokens' in sampling_config
assert sampling_config['target_len'] == 10
assert sampling_config['temperature'] == 1.0
assert sampling_config['max_new_tokens'] == 5

@pytest.mark.integration
def test_full_mock_pipeline(self, mock_transformers, sampling_config):
"""Test that the full mocked pipeline works."""
tokenizer = mock_transformers['tokenizer']
model = mock_transformers['model']

# Test tokenization
tokens = tokenizer.encode("test input")
assert isinstance(tokens, list)
assert len(tokens) > 0

# Test model inference
model_output = model()
assert hasattr(model_output, 'logits')
assert isinstance(model_output.logits, torch.Tensor)

# Test decoding
decoded = tokenizer.decode([1, 2, 3])
assert isinstance(decoded, str)
assert len(decoded) > 0


class TestMarkers:
"""Test that pytest markers work correctly."""

@pytest.mark.unit
def test_unit_marker(self):
"""Test with unit marker."""
assert True

@pytest.mark.integration
def test_integration_marker(self):
"""Test with integration marker."""
assert True

@pytest.mark.slow
def test_slow_marker(self):
"""Test with slow marker."""
assert True


def test_basic_functionality():
"""Basic test to verify pytest is working."""
assert 1 + 1 == 2


def test_torch_basic():
"""Test basic PyTorch functionality."""
tensor = torch.tensor([1, 2, 3])
assert tensor.shape == (3,)
assert tensor.sum().item() == 6
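
These infrastructure tests only exercise the fixtures themselves. A hedged sketch of how the same fixtures could later drive a unit test of the sampler; `greedy_decode` below is a hypothetical stand-in, not the speculative-sampling entry point, which this PR does not touch:

# Sketch only — greedy_decode is a hypothetical placeholder, not project code.
import pytest
import torch


def greedy_decode(model, tokens, max_new_tokens):
    """Append the argmax of the model's last logit position, max_new_tokens times."""
    for _ in range(max_new_tokens):
        logits = model(tokens).logits  # mock_model returns (1, 5, 50257) logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
    return tokens


@pytest.mark.unit
def test_decode_appends_requested_tokens(mock_model, sample_tokens, sampling_config):
    out = greedy_decode(mock_model, sample_tokens, sampling_config["max_new_tokens"])
    assert out.shape[-1] == sample_tokens.shape[-1] + sampling_config["max_new_tokens"]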
Empty file added tests/unit/__init__.py
Empty file.