Skip to content
Merged
Changes from 14 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
3a63d36
attempt 1 at fixing the issue
icfaust Jan 1, 2025
1346532
probe error'
icfaust Jan 1, 2025
7545008
fix basic mistake
icfaust Jan 1, 2025
8179a90
sneaky skips
icfaust Jan 1, 2025
244004c
Update test_common.py
icfaust Jan 2, 2025
5c602da
another mistake
icfaust Jan 2, 2025
8faac56
Revert "another mistake"
icfaust Jan 2, 2025
ff5febf
switch to fixture
icfaust Jan 2, 2025
a0aeba2
fix and optimize
icfaust Jan 2, 2025
0c56de6
attempt with multiprocessing
icfaust Jan 2, 2025
626315d
fix errors
icfaust Jan 2, 2025
4f2c233
add comments
icfaust Jan 2, 2025
bf1cb47
add missing text
icfaust Jan 2, 2025
20a66ca
Update test_common.py
icfaust Jan 2, 2025
d68b979
Apply suggestions from code review
icfaust Jan 13, 2025
9de1f51
Update sklearnex/tests/test_common.py
icfaust Jan 13, 2025
6eaec9a
Update test_common.py
icfaust Jan 13, 2025
c285913
switch based on recommendation
icfaust Jan 13, 2025
e88e32d
Merge branch 'uxlfoundation:main' into dev/trace_removal
icfaust Jan 13, 2025
17d6c89
Update test_common.py
icfaust Jan 13, 2025
01198f0
revert to test
icfaust Jan 13, 2025
a639a1a
Introduce infrastructure to run on main process
icfaust Jan 13, 2025
9ed8ac8
Update test_common.py
icfaust Jan 13, 2025
c53040f
Update test_common.py
icfaust Jan 13, 2025
8d28702
Update test_common.py
icfaust Jan 14, 2025
b00c969
Merge branch 'uxlfoundation:main' into dev/trace_removal
icfaust Jan 14, 2025
379b2dd
Update test_common.py
icfaust Jan 14, 2025
c8275c0
Update test_common.py
icfaust Jan 14, 2025
8569186
Update test_common.py
icfaust Jan 14, 2025
a74de05
Update test_common.py
icfaust Jan 14, 2025
335c179
Update test_common.py
icfaust Jan 15, 2025
56267b6
Update test_common.py
icfaust Jan 15, 2025
7bb1c6b
Update test_common.py
icfaust Jan 15, 2025
9696c63
Update test_common.py
icfaust Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 110 additions & 29 deletions sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
# ==============================================================================

import importlib.util
import io
import os
import pathlib
import pkgutil
import re
import sys
import trace
from contextlib import redirect_stdout
from multiprocessing import Pipe, Process, get_context

import pytest
from sklearn.utils import all_estimators
Expand Down Expand Up @@ -225,23 +228,120 @@ def _commonpath(inp):
_TRACE_BLOCK_LIST = _whitelist_to_blacklist()


def sklearnex_trace(estimator, method):
"""Generate a trace of all function calls in calling estimator.method.

Parameters
----------
estimator : str
name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES

method : str
name of estimator method which is to be traced and stored

Returns
-------
text: str
Returns a string output (captured stdout of a python Trace call). It is a
modified version to be more informative, completed by a monkeypatching
of trace._modname.
"""
# get estimator
try:
est = PATCHED_MODELS[estimator]()
except KeyError:
est = SPECIAL_INSTANCES[estimator]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try:
est = PATCHED_MODELS[estimator]()
except KeyError:
est = SPECIAL_INSTANCES[estimator]
estimator = (SPECIAL_INSTANCES | PATCHED_MODELS)[estimator_name]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to revert this, because SPECIAL_INSTANCES is a special dictionary of estimators which uses sklearn's clone (in order to guarantee that there is no hysteresis between uses of the instance). And the patched models are classes. To be honest, the difference is tech debt that I introduced at the beginning of 2024, as I was trying to unify the centralized testing. Hindsight I would structure things like SPECIAL_INSTANCES.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then I'd go with PATCHED_MODELS.get(estimator_name, None) or SPECIAL_INSTANCES[estimator_name]. I don't want to waste 4 lines on something that doesn't contribute to the function logic.

Copy link
Contributor Author

@icfaust icfaust Jan 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took me some time to figure this out, turns out the ensemble algorithms of sklearn break the suggestion construction.

from sklearn.ensemble import RandomForestRegressor
RandomForestRegressor() or 3

will yield:
AttributeError: 'RandomForestRegressor' object has no attribute 'estimators_'. Did you mean: 'estimator'?

which means I cannot use this in this case, its definitely doing something with the or operator and checking if its an iterable. Its not something on our side, but comes from sklearn conformance. It comes from sklearn for whatever reason defining a __len__ for ensemble estimators thats only valid after fitting.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for sending you down a rabbit hole. I would still prefer a different implementation because I think try/except is a bit of an overkill

estimator = PATCHED_MODELS[estimator_name] if estimator_name in PATCHED_MODELS else SPECIAL_INSTANCES[estimator_name]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no apologies necessary, it showed i hadnt handled a failure case properly and would lead to pytest hanging, which would have been a nightmare to debug. now it should error 'gracefully' in CI (testing it now)


# get dataset
X, y = gen_dataset(est)[0]
# fit dataset if method does not contain 'fit'
if "fit" not in method:
est.fit(X, y)

# monkeypatch new modname for clearer info
orig_modname = trace._modname
try:
# initialize tracer to have a more verbose module naming
# this impacts ignoremods, but it is not used.
trace._modname = _fullpath
tracer = trace.Trace(
count=0,
trace=1,
ignoredirs=_TRACE_BLOCK_LIST,
)
# call trace on method with dataset
f = io.StringIO()
with redirect_stdout(f):
tracer.runfunc(call_method, est, method, X, y)
return f.getvalue()
finally:
trace._modname = orig_modname


def _trace_daemon(pipe):
"""function interface for the other process. Information
exchanged using a multiprocess.Pipe"""
# a sent value with inherent conversion to False will break
# the while loop and complete the function
while key := pipe.recv():
text = ""
try:
estimator, method = key
text = sklearnex_trace(estimator, method)
finally:
pipe.send(text)


@pytest.fixture(scope="module")
def isolated_trace():
"""Generates a separate python process for isolated sklearnex traces.

It is a module scope fixture due to the overhead of importing all the
various dependencies and is done once before all the various tests.
Each test will first check a cached value, if not existent it will have
the waiting child process generate the trace and return the text for
caching on its behalf. The isolated process is stopped at test teardown.

Yields
-------
pipe_parent: multiprocess.Connection
one end of a duplex pipe to be used by other pytest fixtures for
communicating with the special isolated tracing python instance
for sklearnex estimators.
"""
try:
# force use of 'spawn' to guarantee a clean python environment
# from possible coverage arc tracing
ctx = get_context("spawn")
pipe_parent, pipe_child = ctx.Pipe()
p = ctx.Process(target=_trace_daemon, args=(pipe_child,), daemon=True)
p.start()
yield pipe_parent
finally:
# guarantee closing of the process via a try-catch-finally
# passing False terminates _trace_daemon's loop
pipe_parent.send(False)
pipe_parent.close()
pipe_child.close()
p.join()
p.close()


@pytest.fixture
def estimator_trace(estimator, method, cache, capsys, monkeypatch):
"""Generate a trace of all function calls in calling estimator.method with cache.
def estimator_trace(estimator, method, cache, isolated_trace):
"""Create cache of all function calls in calling estimator.method.

Parameters
----------
estimator : str
name of estimator which is a key from PATCHED_MODELS or
name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES

method : str
name of estimator method which is to be traced and stored

cache: pytest.fixture (standard)

capsys: pytest.fixture (standard)

monkeypatch: pytest.fixture (standard)
isolated_trace: pytest.fixture (test_common.py)

Returns
-------
Expand All @@ -256,31 +356,12 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch):
key = "-".join((str(estimator), method))
flag = cache.get("key", "") != key
if flag:
# get estimator
try:
est = PATCHED_MODELS[estimator]()
except KeyError:
est = SPECIAL_INSTANCES[estimator]

# get dataset
X, y = gen_dataset(est)[0]
# fit dataset if method does not contain 'fit'
if "fit" not in method:
est.fit(X, y)

# initialize tracer to have a more verbose module naming
# this impacts ignoremods, but it is not used.
monkeypatch.setattr(trace, "_modname", _fullpath)
tracer = trace.Trace(
count=0,
trace=1,
ignoredirs=_TRACE_BLOCK_LIST,
)
# call trace on method with dataset
tracer.runfunc(call_method, est, method, X, y)
isolated_trace.send((estimator, method))
text = isolated_trace.recv()
# provide a minimal error in the case that the tracing doesn't function
assert text, f"trace_daemon failure for {estimator}.{method}"

# collect trace for analysis
text = capsys.readouterr().out
for modulename, file in _TRACE_ALLOW_DICT.items():
text = text.replace(file, modulename)
regex_func = (
Expand Down
Loading