Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions pytest_split_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,23 @@
from _pytest.config import create_terminal_writer
import pytest

def get_group_size_and_start(total_items, total_groups, group_id):
"""Calculate group size and start index."""
base_size = total_items // total_groups
rem = total_items % total_groups

def get_group_size(total_items, total_groups, group_id):
"""Return the group size."""
base = total_items // total_groups
return base + 1 if group_id <= total_items % total_groups else base
start = base_size * (group_id - 1) + min(group_id - 1, rem)
size = base_size + 1 if group_id <= rem else base_size

return (start, size)

def get_group(items, group_size, group_id):
def get_group(items, total_groups, group_id):
"""Get the items from the passed in group based on group size."""
start = group_size * (group_id - 1)
end = start + group_size
if not 0 < group_id <= total_groups:
raise ValueError("Invalid test-group argument")

return items[start:end]
start, size = get_group_size_and_start(len(items), total_groups, group_id)
return items[start:start+size]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could replace all of this with:

return items[group_id:len(items):group_size]

This is the three-argument slice-operator: [start:end:step].



def pytest_addoption(parser):
Expand All @@ -44,9 +48,6 @@ def pytest_collection_modifyitems(session, config, items):
if not group_count or not group_id:
return

if not 0 < group_id <= group_count:
raise ValueError("Invalid test-group argument")

test_dict = {item.name: item for item in items}
original_order = {item: index for index, item in enumerate(items)}

Expand Down Expand Up @@ -77,8 +78,7 @@ def pytest_collection_modifyitems(session, config, items):

total_unscheduled_items = len(unscheduled_tests)

group_size = get_group_size(total_unscheduled_items, group_count, group_id)
tests_in_group = get_group(unscheduled_tests, group_size, group_id)
tests_in_group = get_group(unscheduled_tests, group_count, group_id)
items[:] = tests_in_group + prescheduled_tests

items.sort(key=original_order.__getitem__)
Expand Down
20 changes: 10 additions & 10 deletions tests/test_groups.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
import pytest

from pytest_split_tests import get_group, get_group_size
from pytest_split_tests import get_group, get_group_size_and_start


def test_group_size_computed_correctly_for_even_group():
expected = 8
actual = get_group_size(32, 4) # 32 total tests; 4 groups
def test_group_params_computed_correctly_for_even_group():
expected = [(0, 8), (8, 8), (16, 8), (24, 8)]
actual = [get_group_size_and_start(32, 4, group_id) for group_id in range(1, 5)] # 32 total tests; 4 groups

assert expected == actual


def test_group_size_computed_correctly_for_odd_group():
expected = 8
actual = get_group_size(31, 4) # 31 total tests; 4 groups
expected = [(0, 8), (8, 8), (16, 8), (24, 7)]
actual = [get_group_size_and_start(31, 4, group_id) for group_id in range(1, 5)] # 32 total tests; 4 groups

assert expected == actual


def test_group_is_the_proper_size():
items = [str(i) for i in range(32)]
group = get_group(items, 8, 1)
group = get_group(items, 4, 1)

assert len(group) == 8


def test_all_groups_together_form_original_set_of_tests():
items = [str(i) for i in range(32)]

groups = [get_group(items, 8, i) for i in range(1, 5)]
groups = [get_group(items, 4, i) for i in range(1, 5)]

combined = []
for group in groups:
Expand All @@ -40,11 +40,11 @@ def test_group_that_is_too_high_raises_value_error():
items = [str(i) for i in range(32)]

with pytest.raises(ValueError):
get_group(items, 8, 5)
get_group(items, 4, 5)


def test_group_that_is_too_low_raises_value_error():
items = [str(i) for i in range(32)]

with pytest.raises(ValueError):
get_group(items, 8, 0)
get_group(items, 4, 0)