77from docetl .operations .utils import APIWrapper
88from docetl .config_wrapper import ConfigWrapper
99from dotenv import load_dotenv
10- from tests .conftest import api_wrapper
10+ from tests .conftest import runner
1111
1212
1313load_dotenv ()
@@ -35,19 +35,19 @@ def filter_sample_data():
3535
3636
3737def test_filter_operation (
38- filter_config , default_model , max_threads , filter_sample_data , api_wrapper
38+ filter_config , default_model , max_threads , filter_sample_data , runner
3939):
40- operation = FilterOperation (api_wrapper , filter_config , default_model , max_threads )
40+ operation = FilterOperation (runner , filter_config , default_model , max_threads )
4141 results , cost = operation .execute (filter_sample_data )
4242
4343 assert len (results ) < len (filter_sample_data )
4444 assert all (len (result ["text" ].split ()) > 3 for result in results )
4545
4646
4747def test_filter_operation_empty_input (
48- filter_config , default_model , max_threads , api_wrapper
48+ filter_config , default_model , max_threads , runner
4949):
50- operation = FilterOperation (api_wrapper , filter_config , default_model , max_threads )
50+ operation = FilterOperation (runner , filter_config , default_model , max_threads )
5151 results , cost = operation .execute ([])
5252
5353 assert len (results ) == 0
@@ -96,10 +96,10 @@ def dict_unnest_sample_data():
9696
9797
9898def test_dict_unnest_operation (
99- dict_unnest_config , default_model , max_threads , dict_unnest_sample_data , api_wrapper
99+ dict_unnest_config , default_model , max_threads , dict_unnest_sample_data , runner
100100):
101101 operation = UnnestOperation (
102- api_wrapper , dict_unnest_config , default_model , max_threads
102+ runner , dict_unnest_config , default_model , max_threads
103103 )
104104 results , cost = operation .execute (dict_unnest_sample_data )
105105
@@ -118,10 +118,10 @@ def test_dict_unnest_operation(
118118
119119
120120def test_dict_unnest_operation_empty_input (
121- dict_unnest_config , default_model , max_threads , api_wrapper
121+ dict_unnest_config , default_model , max_threads , runner
122122):
123123 operation = UnnestOperation (
124- api_wrapper , dict_unnest_config , default_model , max_threads
124+ runner , dict_unnest_config , default_model , max_threads
125125 )
126126 results , cost = operation .execute ([])
127127
@@ -130,9 +130,9 @@ def test_dict_unnest_operation_empty_input(
130130
131131
132132def test_unnest_operation (
133- unnest_config , default_model , max_threads , unnest_sample_data , api_wrapper
133+ unnest_config , default_model , max_threads , unnest_sample_data , runner
134134):
135- operation = UnnestOperation (api_wrapper , unnest_config , default_model , max_threads )
135+ operation = UnnestOperation (runner , unnest_config , default_model , max_threads )
136136 results , cost = operation .execute (unnest_sample_data )
137137
138138 assert len (results ) == 6 # 3 + 2 + 1
@@ -141,9 +141,9 @@ def test_unnest_operation(
141141
142142
143143def test_unnest_operation_empty_input (
144- unnest_config , default_model , max_threads , api_wrapper
144+ unnest_config , default_model , max_threads , runner
145145):
146- operation = UnnestOperation (api_wrapper , unnest_config , default_model , max_threads )
146+ operation = UnnestOperation (runner , unnest_config , default_model , max_threads )
147147 results , cost = operation .execute ([])
148148
149149 assert len (results ) == 0
@@ -182,10 +182,10 @@ def right_data():
182182
183183
184184def test_equijoin_operation (
185- equijoin_config , default_model , max_threads , left_data , right_data , api_wrapper
185+ equijoin_config , default_model , max_threads , left_data , right_data , runner
186186):
187187 operation = EquijoinOperation (
188- api_wrapper , equijoin_config , default_model , max_threads
188+ runner , equijoin_config , default_model , max_threads
189189 )
190190 results , cost = operation .execute (left_data , right_data )
191191
@@ -194,10 +194,10 @@ def test_equijoin_operation(
194194
195195
196196def test_equijoin_operation_empty_input (
197- equijoin_config , default_model , max_threads , api_wrapper
197+ equijoin_config , default_model , max_threads , runner
198198):
199199 operation = EquijoinOperation (
200- api_wrapper , equijoin_config , default_model , max_threads
200+ runner , equijoin_config , default_model , max_threads
201201 )
202202 results , cost = operation .execute ([], [])
203203
@@ -253,9 +253,9 @@ def sample_data():
253253
254254
255255def test_split_operation (
256- split_config , default_model , max_threads , sample_data , api_wrapper
256+ split_config , default_model , max_threads , sample_data , runner
257257):
258- operation = SplitOperation (api_wrapper , split_config , default_model , max_threads )
258+ operation = SplitOperation (runner , split_config , default_model , max_threads )
259259 results , cost = operation .execute (sample_data )
260260
261261 assert len (results ) > len (sample_data )
@@ -287,14 +287,14 @@ def test_split_operation(
287287
288288
289289def test_gather_operation (
290- split_config , gather_config , default_model , max_threads , sample_data , api_wrapper
290+ split_config , gather_config , default_model , max_threads , sample_data , runner
291291):
292292 # First, split the data
293- split_op = SplitOperation (api_wrapper , split_config , default_model , max_threads )
293+ split_op = SplitOperation (runner , split_config , default_model , max_threads )
294294 split_results , _ = split_op .execute (sample_data )
295295
296296 # Now, gather the split results
297- gather_op = GatherOperation (api_wrapper , gather_config , default_model , max_threads )
297+ gather_op = GatherOperation (runner , gather_config , default_model , max_threads )
298298 results , cost = gather_op .execute (split_results )
299299
300300 assert len (results ) == len (split_results )
@@ -313,10 +313,10 @@ def test_gather_operation(
313313
314314
315315def test_split_gather_combined (
316- split_config , gather_config , default_model , max_threads , sample_data , api_wrapper
316+ split_config , gather_config , default_model , max_threads , sample_data , runner
317317):
318- split_op = SplitOperation (api_wrapper , split_config , default_model , max_threads )
319- gather_op = GatherOperation (api_wrapper , gather_config , default_model , max_threads )
318+ split_op = SplitOperation (runner , split_config , default_model , max_threads )
319+ gather_op = GatherOperation (runner , gather_config , default_model , max_threads )
320320
321321 split_results , split_cost = split_op .execute (sample_data )
322322 gather_results , gather_cost = gather_op .execute (split_results )
@@ -334,10 +334,10 @@ def test_split_gather_combined(
334334
335335
336336def test_split_gather_empty_input (
337- split_config , gather_config , default_model , max_threads , api_wrapper
337+ split_config , gather_config , default_model , max_threads , runner
338338):
339- split_op = SplitOperation (api_wrapper , split_config , default_model , max_threads )
340- gather_op = GatherOperation (api_wrapper , gather_config , default_model , max_threads )
339+ split_op = SplitOperation (runner , split_config , default_model , max_threads )
340+ gather_op = GatherOperation (runner , gather_config , default_model , max_threads )
341341
342342 split_results , split_cost = split_op .execute ([])
343343 assert len (split_results ) == 0
0 commit comments