Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions pytest_split_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,23 @@
from _pytest.config import create_terminal_writer
import pytest

def get_group_size_and_start(total_items, total_groups, group_id):
"""Calculate group size and start index."""
base_size = total_items // total_groups
rem = total_items % total_groups

def get_group_size(total_items, total_groups):
"""Return the group size."""
return int(math.ceil(float(total_items) / total_groups))
start = base_size * (group_id - 1) + min(group_id - 1, rem)
size = base_size + 1 if group_id <= rem else base_size

return (start, size)

def get_group(items, group_size, group_id):
def get_group(items, total_groups, group_id):
"""Get the items from the passed in group based on group size."""
start = group_size * (group_id - 1)
end = start + group_size

if start >= len(items) or start < 0:
if not 0 < group_id <= total_groups:
raise ValueError("Invalid test-group argument")

return items[start:end]
start, size = get_group_size_and_start(len(items), total_groups, group_id)
return items[start:start+size]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could replace all of this with:

return items[group_id:len(items):group_size]

This is the three-argument slice-operator: [start:end:step].



def pytest_addoption(parser):
Expand All @@ -40,8 +42,8 @@ def pytest_collection_modifyitems(session, config, items):
yield
group_count = config.getoption('test-group-count')
group_id = config.getoption('test-group')
seed = config.getoption('random-seed', False)
prescheduled_path = config.getoption('prescheduled', None)
seed = config.getoption('random-seed')
prescheduled_path = config.getoption('prescheduled')

if not group_count or not group_id:
return
Expand Down Expand Up @@ -70,14 +72,13 @@ def pytest_collection_modifyitems(session, config, items):
if test_name in test_dict]
unscheduled_tests = [item for item in items if item not in all_prescheduled_tests]

if seed is not False:
if seed is not None:
seeded = Random(seed)
seeded.shuffle(unscheduled_tests)

total_unscheduled_items = len(unscheduled_tests)

group_size = get_group_size(total_unscheduled_items, group_count)
tests_in_group = get_group(unscheduled_tests, group_size, group_id)
tests_in_group = get_group(unscheduled_tests, group_count, group_id)
items[:] = tests_in_group + prescheduled_tests

items.sort(key=original_order.__getitem__)
Expand Down
20 changes: 10 additions & 10 deletions tests/test_groups.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
import pytest

from pytest_split_tests import get_group, get_group_size
from pytest_split_tests import get_group, get_group_size_and_start


def test_group_size_computed_correctly_for_even_group():
expected = 8
actual = get_group_size(32, 4) # 32 total tests; 4 groups
def test_group_params_computed_correctly_for_even_group():
expected = [(0, 8), (8, 8), (16, 8), (24, 8)]
actual = [get_group_size_and_start(32, 4, group_id) for group_id in range(1, 5)] # 32 total tests; 4 groups

assert expected == actual


def test_group_size_computed_correctly_for_odd_group():
expected = 8
actual = get_group_size(31, 4) # 31 total tests; 4 groups
expected = [(0, 8), (8, 8), (16, 8), (24, 7)]
actual = [get_group_size_and_start(31, 4, group_id) for group_id in range(1, 5)] # 32 total tests; 4 groups

assert expected == actual


def test_group_is_the_proper_size():
items = [str(i) for i in range(32)]
group = get_group(items, 8, 1)
group = get_group(items, 4, 1)

assert len(group) == 8


def test_all_groups_together_form_original_set_of_tests():
items = [str(i) for i in range(32)]

groups = [get_group(items, 8, i) for i in range(1, 5)]
groups = [get_group(items, 4, i) for i in range(1, 5)]

combined = []
for group in groups:
Expand All @@ -40,11 +40,11 @@ def test_group_that_is_too_high_raises_value_error():
items = [str(i) for i in range(32)]

with pytest.raises(ValueError):
get_group(items, 8, 5)
get_group(items, 4, 5)


def test_group_that_is_too_low_raises_value_error():
items = [str(i) for i in range(32)]

with pytest.raises(ValueError):
get_group(items, 8, 0)
get_group(items, 4, 0)