Skip to content

Commit 4cee4d7

Browse files
Clean up and reorganize pytest tests (#377)
* Replace api_wrapper with runner in test fixtures and configurations

  Co-authored-by: ss.shankar505 <[email protected]>

* Refactor test fixtures and reorganize configuration in test files

  Co-authored-by: ss.shankar505 <[email protected]>

---------

Co-authored-by: Cursor Agent <[email protected]>
1 parent 9a72a6b commit 4cee4d7

17 files changed

+363 additions, -403 deletions (lines changed)

tests/basic/test_basic_filter_split_gather.py

Lines changed: 28 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from docetl.operations.utils import APIWrapper
88
from docetl.config_wrapper import ConfigWrapper
99
from dotenv import load_dotenv
10-
from tests.conftest import api_wrapper
10+
from tests.conftest import runner
1111

1212

1313
load_dotenv()
@@ -35,19 +35,19 @@ def filter_sample_data():
3535

3636

3737
def test_filter_operation(
38-
filter_config, default_model, max_threads, filter_sample_data, api_wrapper
38+
filter_config, default_model, max_threads, filter_sample_data, runner
3939
):
40-
operation = FilterOperation(api_wrapper, filter_config, default_model, max_threads)
40+
operation = FilterOperation(runner, filter_config, default_model, max_threads)
4141
results, cost = operation.execute(filter_sample_data)
4242

4343
assert len(results) < len(filter_sample_data)
4444
assert all(len(result["text"].split()) > 3 for result in results)
4545

4646

4747
def test_filter_operation_empty_input(
48-
filter_config, default_model, max_threads, api_wrapper
48+
filter_config, default_model, max_threads, runner
4949
):
50-
operation = FilterOperation(api_wrapper, filter_config, default_model, max_threads)
50+
operation = FilterOperation(runner, filter_config, default_model, max_threads)
5151
results, cost = operation.execute([])
5252

5353
assert len(results) == 0
@@ -96,10 +96,10 @@ def dict_unnest_sample_data():
9696

9797

9898
def test_dict_unnest_operation(
99-
dict_unnest_config, default_model, max_threads, dict_unnest_sample_data, api_wrapper
99+
dict_unnest_config, default_model, max_threads, dict_unnest_sample_data, runner
100100
):
101101
operation = UnnestOperation(
102-
api_wrapper, dict_unnest_config, default_model, max_threads
102+
runner, dict_unnest_config, default_model, max_threads
103103
)
104104
results, cost = operation.execute(dict_unnest_sample_data)
105105

@@ -118,10 +118,10 @@ def test_dict_unnest_operation(
118118

119119

120120
def test_dict_unnest_operation_empty_input(
121-
dict_unnest_config, default_model, max_threads, api_wrapper
121+
dict_unnest_config, default_model, max_threads, runner
122122
):
123123
operation = UnnestOperation(
124-
api_wrapper, dict_unnest_config, default_model, max_threads
124+
runner, dict_unnest_config, default_model, max_threads
125125
)
126126
results, cost = operation.execute([])
127127

@@ -130,9 +130,9 @@ def test_dict_unnest_operation_empty_input(
130130

131131

132132
def test_unnest_operation(
133-
unnest_config, default_model, max_threads, unnest_sample_data, api_wrapper
133+
unnest_config, default_model, max_threads, unnest_sample_data, runner
134134
):
135-
operation = UnnestOperation(api_wrapper, unnest_config, default_model, max_threads)
135+
operation = UnnestOperation(runner, unnest_config, default_model, max_threads)
136136
results, cost = operation.execute(unnest_sample_data)
137137

138138
assert len(results) == 6 # 3 + 2 + 1
@@ -141,9 +141,9 @@ def test_unnest_operation(
141141

142142

143143
def test_unnest_operation_empty_input(
144-
unnest_config, default_model, max_threads, api_wrapper
144+
unnest_config, default_model, max_threads, runner
145145
):
146-
operation = UnnestOperation(api_wrapper, unnest_config, default_model, max_threads)
146+
operation = UnnestOperation(runner, unnest_config, default_model, max_threads)
147147
results, cost = operation.execute([])
148148

149149
assert len(results) == 0
@@ -182,10 +182,10 @@ def right_data():
182182

183183

184184
def test_equijoin_operation(
185-
equijoin_config, default_model, max_threads, left_data, right_data, api_wrapper
185+
equijoin_config, default_model, max_threads, left_data, right_data, runner
186186
):
187187
operation = EquijoinOperation(
188-
api_wrapper, equijoin_config, default_model, max_threads
188+
runner, equijoin_config, default_model, max_threads
189189
)
190190
results, cost = operation.execute(left_data, right_data)
191191

@@ -194,10 +194,10 @@ def test_equijoin_operation(
194194

195195

196196
def test_equijoin_operation_empty_input(
197-
equijoin_config, default_model, max_threads, api_wrapper
197+
equijoin_config, default_model, max_threads, runner
198198
):
199199
operation = EquijoinOperation(
200-
api_wrapper, equijoin_config, default_model, max_threads
200+
runner, equijoin_config, default_model, max_threads
201201
)
202202
results, cost = operation.execute([], [])
203203

@@ -253,9 +253,9 @@ def sample_data():
253253

254254

255255
def test_split_operation(
256-
split_config, default_model, max_threads, sample_data, api_wrapper
256+
split_config, default_model, max_threads, sample_data, runner
257257
):
258-
operation = SplitOperation(api_wrapper, split_config, default_model, max_threads)
258+
operation = SplitOperation(runner, split_config, default_model, max_threads)
259259
results, cost = operation.execute(sample_data)
260260

261261
assert len(results) > len(sample_data)
@@ -287,14 +287,14 @@ def test_split_operation(
287287

288288

289289
def test_gather_operation(
290-
split_config, gather_config, default_model, max_threads, sample_data, api_wrapper
290+
split_config, gather_config, default_model, max_threads, sample_data, runner
291291
):
292292
# First, split the data
293-
split_op = SplitOperation(api_wrapper, split_config, default_model, max_threads)
293+
split_op = SplitOperation(runner, split_config, default_model, max_threads)
294294
split_results, _ = split_op.execute(sample_data)
295295

296296
# Now, gather the split results
297-
gather_op = GatherOperation(api_wrapper, gather_config, default_model, max_threads)
297+
gather_op = GatherOperation(runner, gather_config, default_model, max_threads)
298298
results, cost = gather_op.execute(split_results)
299299

300300
assert len(results) == len(split_results)
@@ -313,10 +313,10 @@ def test_gather_operation(
313313

314314

315315
def test_split_gather_combined(
316-
split_config, gather_config, default_model, max_threads, sample_data, api_wrapper
316+
split_config, gather_config, default_model, max_threads, sample_data, runner
317317
):
318-
split_op = SplitOperation(api_wrapper, split_config, default_model, max_threads)
319-
gather_op = GatherOperation(api_wrapper, gather_config, default_model, max_threads)
318+
split_op = SplitOperation(runner, split_config, default_model, max_threads)
319+
gather_op = GatherOperation(runner, gather_config, default_model, max_threads)
320320

321321
split_results, split_cost = split_op.execute(sample_data)
322322
gather_results, gather_cost = gather_op.execute(split_results)
@@ -334,10 +334,10 @@ def test_split_gather_combined(
334334

335335

336336
def test_split_gather_empty_input(
337-
split_config, gather_config, default_model, max_threads, api_wrapper
337+
split_config, gather_config, default_model, max_threads, runner
338338
):
339-
split_op = SplitOperation(api_wrapper, split_config, default_model, max_threads)
340-
gather_op = GatherOperation(api_wrapper, gather_config, default_model, max_threads)
339+
split_op = SplitOperation(runner, split_config, default_model, max_threads)
340+
gather_op = GatherOperation(runner, gather_config, default_model, max_threads)
341341

342342
split_results, split_cost = split_op.execute([])
343343
assert len(split_results) == 0

0 commit comments

Comments
 (0)