Skip to content

Commit 6742689

Browse files
icfaustahuber21
andauthored
[testing, CI] fix coverage statistics issue caused by test_common.py tracer patching (#2237)
* attempt 1 at fixing the issue * probe error' * fix basic mistake * sneaky skips * Update test_common.py * another mistake * Revert "another mistake" This reverts commit 5c602da. * switch to fixture * fix and optimize * attempt with multiprocessing * fix errors * add comments * add missing text * Update test_common.py * Apply suggestions from code review Co-authored-by: Andreas Huber <[email protected]> * Update sklearnex/tests/test_common.py Co-authored-by: Andreas Huber <[email protected]> * Update test_common.py * switch based on recommendation * Update test_common.py * revert to test * Introduce infrastructure to run on main process * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py --------- Co-authored-by: Andreas Huber <[email protected]>
1 parent c76273d commit 6742689

File tree

1 file changed

+130
-29
lines changed

1 file changed

+130
-29
lines changed

sklearnex/tests/test_common.py

Lines changed: 130 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
# ==============================================================================
1616

1717
import importlib.util
18+
import io
1819
import os
1920
import pathlib
2021
import pkgutil
2122
import re
2223
import sys
2324
import trace
25+
from contextlib import redirect_stdout
26+
from multiprocessing import Pipe, Process, get_context
2427

2528
import pytest
2629
from sklearn.utils import all_estimators
@@ -225,23 +228,137 @@ def _commonpath(inp):
225228
_TRACE_BLOCK_LIST = _whitelist_to_blacklist()
226229

227230

231+
def sklearnex_trace(estimator_name, method_name):
232+
"""Generate a trace of all function calls in calling estimator.method.
233+
234+
Parameters
235+
----------
236+
estimator_name : str
237+
name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES
238+
239+
method_name : str
240+
name of estimator method which is to be traced and stored
241+
242+
Returns
243+
-------
244+
text: str
245+
Returns a string output (captured stdout of a python Trace call). It is a
246+
modified version to be more informative, completed by a monkeypatching
247+
of trace._modname.
248+
"""
249+
# get estimator
250+
est = (
251+
PATCHED_MODELS[estimator_name]()
252+
if estimator_name in PATCHED_MODELS
253+
else SPECIAL_INSTANCES[estimator_name]
254+
)
255+
256+
# get dataset
257+
X, y = gen_dataset(est)[0]
258+
# fit dataset if method does not contain 'fit'
259+
if "fit" not in method_name:
260+
est.fit(X, y)
261+
262+
# monkeypatch new modname for clearer info
263+
orig_modname = trace._modname
264+
try:
265+
# initialize tracer to have a more verbose module naming
266+
# this impacts ignoremods, but it is not used.
267+
trace._modname = _fullpath
268+
tracer = trace.Trace(
269+
count=0,
270+
trace=1,
271+
ignoredirs=_TRACE_BLOCK_LIST,
272+
)
273+
# call trace on method with dataset
274+
f = io.StringIO()
275+
with redirect_stdout(f):
276+
tracer.runfunc(call_method, est, method_name, X, y)
277+
return f.getvalue()
278+
finally:
279+
trace._modname = orig_modname
280+
281+
282+
def _trace_daemon(pipe):
283+
"""function interface for the other process. Information
284+
exchanged using a multiprocess.Pipe"""
285+
# a sent value with inherent conversion to False will break
286+
# the while loop and complete the function
287+
while key := pipe.recv():
288+
try:
289+
text = sklearnex_trace(*key)
290+
except:
291+
# catch all exceptions and pass back,
292+
# this way the process still runs
293+
text = ""
294+
finally:
295+
pipe.send(text)
296+
297+
298+
class _FakePipe:
299+
"""Minimalistic representation of a multiprocessing.Pipe for test development.
300+
This allows for running sklearnex_trace in the parent process"""
301+
302+
_text = ""
303+
304+
def send(self, key):
305+
self._text = sklearnex_trace(*key)
306+
307+
def recv(self):
308+
return self._text
309+
310+
311+
@pytest.fixture(scope="module")
312+
def isolated_trace():
313+
"""Generates a separate python process for isolated sklearnex traces.
314+
315+
It is a module scope fixture due to the overhead of importing all the
316+
various dependencies and is done once before all the various tests.
317+
Each test will first check a cached value, if not existent it will have
318+
the waiting child process generate the trace and return the text for
319+
caching on its behalf. The isolated process is stopped at test teardown.
320+
321+
Yields
322+
-------
323+
pipe_parent: multiprocessing.Connection
324+
one end of a duplex pipe to be used by other pytest fixtures for
325+
communicating with the special isolated tracing python instance
326+
for sklearnex estimators.
327+
"""
328+
# yield _FakePipe()
329+
try:
330+
# force use of 'spawn' to guarantee a clean python environment
331+
# from possible coverage arc tracing
332+
ctx = get_context("spawn")
333+
pipe_parent, pipe_child = ctx.Pipe()
334+
p = ctx.Process(target=_trace_daemon, args=(pipe_child,), daemon=True)
335+
p.start()
336+
yield pipe_parent
337+
finally:
338+
# guarantee closing of the process via a try-catch-finally
339+
# passing False terminates _trace_daemon's loop
340+
pipe_parent.send(False)
341+
pipe_parent.close()
342+
pipe_child.close()
343+
p.join()
344+
p.close()
345+
346+
228347
@pytest.fixture
229-
def estimator_trace(estimator, method, cache, capsys, monkeypatch):
230-
"""Generate a trace of all function calls in calling estimator.method with cache.
348+
def estimator_trace(estimator, method, cache, isolated_trace):
349+
"""Create cache of all function calls in calling estimator.method.
231350
232351
Parameters
233352
----------
234353
estimator : str
235-
name of estimator which is a key from PATCHED_MODELS or
354+
name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES
236355
237356
method : str
238357
name of estimator method which is to be traced and stored
239358
240359
cache: pytest.fixture (standard)
241360
242-
capsys: pytest.fixture (standard)
243-
244-
monkeypatch: pytest.fixture (standard)
361+
isolated_trace: pytest.fixture (test_common.py)
245362
246363
Returns
247364
-------
@@ -256,31 +373,15 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch):
256373
key = "-".join((str(estimator), method))
257374
flag = cache.get("key", "") != key
258375
if flag:
259-
# get estimator
260-
try:
261-
est = PATCHED_MODELS[estimator]()
262-
except KeyError:
263-
est = SPECIAL_INSTANCES[estimator]
264-
265-
# get dataset
266-
X, y = gen_dataset(est)[0]
267-
# fit dataset if method does not contain 'fit'
268-
if "fit" not in method:
269-
est.fit(X, y)
270376

271-
# initialize tracer to have a more verbose module naming
272-
# this impacts ignoremods, but it is not used.
273-
monkeypatch.setattr(trace, "_modname", _fullpath)
274-
tracer = trace.Trace(
275-
count=0,
276-
trace=1,
277-
ignoredirs=_TRACE_BLOCK_LIST,
278-
)
279-
# call trace on method with dataset
280-
tracer.runfunc(call_method, est, method, X, y)
377+
isolated_trace.send((estimator, method))
378+
text = isolated_trace.recv()
379+
# if tracing does not function in isolated_trace, run it in parent process and error
380+
if text == "":
381+
sklearnex_trace(estimator, method)
382+
# guarantee failure if intermittent
383+
assert text, f"sklearnex_trace failure for {estimator}.{method}"
281384

282-
# collect trace for analysis
283-
text = capsys.readouterr().out
284385
for modulename, file in _TRACE_ALLOW_DICT.items():
285386
text = text.replace(file, modulename)
286387
regex_func = (

0 commit comments

Comments
 (0)