From 48576c9601388df3248f855c53092f5df6689446 Mon Sep 17 00:00:00 2001 From: Tan An Nie <121005973+tanannie22@users.noreply.github.com> Date: Sat, 8 Mar 2025 00:00:03 +0800 Subject: [PATCH 1/5] Add rbf_kernel Function (#3) * chore(deps): update dependency jinja2 to v3.1.6 [security] (#2345) * feature: add RBF kernel implementation and integrate with dispatcher * integrate RBF kernel in sklearnex/metrics * add benchmark_rbf_kernel.ipynb * revert Jinja2 version change * remove unused RBF kernel mapping from dispatcher --------- Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Andrie Dazlee, Nurin Miza Afiqah Signed-off-by: Andrie Dazlee, Nurin Miza Afiqah Signed-off-by: Mohamad, Siti Nurhanisah Signed-off-by: Tan, An Nie --- examples/notebooks/benchmark_rbf_kernel.ipynb | 853 ++++++++++++++++++ sklearnex/metrics/pairwise.py | 4 + sklearnex/metrics/tests/test_metrics.py | 13 + 3 files changed, 870 insertions(+) create mode 100644 examples/notebooks/benchmark_rbf_kernel.ipynb diff --git a/examples/notebooks/benchmark_rbf_kernel.ipynb b/examples/notebooks/benchmark_rbf_kernel.ipynb new file mode 100644 index 0000000000..c12c58a8ab --- /dev/null +++ b/examples/notebooks/benchmark_rbf_kernel.ipynb @@ -0,0 +1,853 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Benchmarking RBF Kernel Performance\n", + "\n", + "This notebook compares the execution time of the RBF kernel function in scikit-learn and sklearnex. The goal is to evaluate the performance improvements using the oneDAL-optimized implementation.\n", + "\n", + "### Methodology\n", + "- We generate random matrices of different row sizes (2, 5, 10, 100, 1000, 10000) with 3 columns.\n", + "- We measure the execution time of the RBF kernel function in scikit-learn and sklearnex.\n", + "- We compare the results for accuracy and compute the speedup factor.\n", + "\n", + "### Setup & Environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "### System Information ###\n", + "OS: Windows 10 (10.0.22631)\n", + "Windows Edition: Windows 11 Enterprise\n", + "Processor: Intel64 Family 6 Model 154 Stepping 3, GenuineIntel\n", + "CPU Cores: 12 (Physical), 16 (Logical)\n", + "RAM: 63.64 GB\n", + "Graphics Card: Intel(R) Iris(R) Xe Graphics\n", + "\n", + "### Library Versions ###\n", + "Python Version: 3.10.11\n", + "scikit-learn Version: 1.6.1\n", + "NumPy Version: 1.26.4\n", + "Pandas Version: 2.1.3\n" + ] + } + ], + "source": [ + "import platform\n", + "import psutil\n", + "import sklearn\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "print(\"### System Information ###\")\n", + "print(f\"OS: {platform.system()} {platform.release()} ({platform.version()})\")\n", + "print(\"Windows Edition: Windows 11 Enterprise\")\n", + "print(f\"Processor: {platform.processor()}\")\n", + "print(f\"CPU Cores: {psutil.cpu_count(logical=False)} (Physical), {psutil.cpu_count(logical=True)} (Logical)\")\n", + "print(f\"RAM: {round(psutil.virtual_memory().total / (1024**3), 2)} GB\")\n", + "print(\"Graphics Card: Intel(R) Iris(R) Xe Graphics\")\n", + "\n", + "print(\"\\n### Library Versions ###\")\n", + "print(f\"Python Version: {platform.python_version()}\")\n", + "print(f\"scikit-learn Version: {sklearn.__version__}\")\n", + "print(f\"NumPy Version: {np.__version__}\")\n", + "print(f\"Pandas Version: {pd.__version__}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from timeit import default_timer as timer\n", + "from sklearn.metrics.pairwise import rbf_kernel as sklearn_rbf_kernel\n", + "from sklearnex.metrics.pairwise import rbf_kernel as sklearnex_rbf_kernel\n", + "import numpy as np\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "# a list of different sizes of data to be tested\n", + "row_sizes = [2, 5, 10, 100, 1000, 10000]\n", + "col_size = 3\n", + "\n", + "results = []\n", + "\n", + "for row in row_sizes:\n", + " X = np.random.rand(row, col_size)\n", + " Y = np.random.rand(row, col_size)\n", + "\n", + " for i in range(10):\n", + " start = timer()\n", + " sklearn_result = sklearn_rbf_kernel(X, Y)\n", + " sklearn_time = timer() - start\n", + "\n", + " start = timer()\n", + " sklearnex_result = sklearnex_rbf_kernel(X, Y)\n", + " sklearnex_time = timer() - start\n", + "\n", + " results.append({\n", + " \"# of row\": row,\n", + " \"sklearn_result\": sklearn_result,\n", + " \"sklearnex_result\": sklearnex_result,\n", + " \"sklearn_time\": sklearn_time,\n", + " \"sklearnex_time\": sklearnex_time\n", + " })" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Results\n", + "\n", + "The results are stored in a dataframe for easier interpretation and analysis.\n", + "\n", + "- The `result_match` column verifies if both implementations produce identical results.\n", + "- The `speedup` column shows how much faster sklearnex is compared to scikit-learn." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
# of rowsklearn_timesklearnex_timeresult_matchspeedup
020.0005750.002887True0.20
120.0007020.000182True3.86
220.0005320.000176True3.03
320.0005100.000155True3.29
420.0005060.000150True3.38
520.0004770.000146True3.27
620.0004700.000144True3.28
720.0004710.000143True3.29
820.0004680.000143True3.27
920.0004870.000148True3.30
1050.0004770.000170True2.80
1150.0005490.000184True2.98
1250.0003280.000102True3.22
1350.0003070.000107True2.88
1450.0003100.000140True2.22
1550.0003530.000111True3.17
1650.0003450.000112True3.08
1750.0003410.000108True3.17
1850.0003580.000124True2.90
1950.0003880.000119True3.27
20100.0003980.000152True2.61
21100.0004300.000142True3.02
22100.0004170.000130True3.21
23100.0004100.000160True2.57
24100.0003980.000130True3.06
25100.0004020.000134True3.01
26100.0004290.000127True3.37
27100.0004050.000120True3.38
28100.0003770.000116True3.24
29100.0004000.000112True3.56
301000.0005960.000374True1.59
311000.0004540.000226True2.01
321000.0004030.000192True2.10
331000.0003740.000178True2.10
341000.0003530.000178True1.98
351000.0003420.000176True1.95
361000.0003240.000165True1.96
371000.0003180.000159True2.00
381000.0003140.000156True2.01
391000.0003160.000160True1.98
4010000.0133940.003922True3.42
4110000.0152610.002053True7.43
4210000.0121110.002061True5.88
4310000.0127270.001787True7.12
4410000.0113580.001635True6.94
4510000.0115450.001832True6.30
4610000.0117230.001660True7.06
4710000.0118720.001853True6.41
4810000.0117640.001686True6.98
4910000.0120140.001991True6.04
50100001.0189520.099710True10.22
51100001.0202430.107800True9.46
52100001.0325240.098206True10.51
53100001.0519170.109932True9.57
54100001.0580210.097128True10.89
55100001.0108340.100011True10.11
56100001.0106190.101155True9.99
57100001.3952460.239068True5.84
58100001.1493180.101077True11.37
59100000.9896590.098101True10.09
\n", + "
" + ], + "text/plain": [ + " # of row sklearn_time sklearnex_time result_match speedup\n", + "0 2 0.000575 0.002887 True 0.20\n", + "1 2 0.000702 0.000182 True 3.86\n", + "2 2 0.000532 0.000176 True 3.03\n", + "3 2 0.000510 0.000155 True 3.29\n", + "4 2 0.000506 0.000150 True 3.38\n", + "5 2 0.000477 0.000146 True 3.27\n", + "6 2 0.000470 0.000144 True 3.28\n", + "7 2 0.000471 0.000143 True 3.29\n", + "8 2 0.000468 0.000143 True 3.27\n", + "9 2 0.000487 0.000148 True 3.30\n", + "10 5 0.000477 0.000170 True 2.80\n", + "11 5 0.000549 0.000184 True 2.98\n", + "12 5 0.000328 0.000102 True 3.22\n", + "13 5 0.000307 0.000107 True 2.88\n", + "14 5 0.000310 0.000140 True 2.22\n", + "15 5 0.000353 0.000111 True 3.17\n", + "16 5 0.000345 0.000112 True 3.08\n", + "17 5 0.000341 0.000108 True 3.17\n", + "18 5 0.000358 0.000124 True 2.90\n", + "19 5 0.000388 0.000119 True 3.27\n", + "20 10 0.000398 0.000152 True 2.61\n", + "21 10 0.000430 0.000142 True 3.02\n", + "22 10 0.000417 0.000130 True 3.21\n", + "23 10 0.000410 0.000160 True 2.57\n", + "24 10 0.000398 0.000130 True 3.06\n", + "25 10 0.000402 0.000134 True 3.01\n", + "26 10 0.000429 0.000127 True 3.37\n", + "27 10 0.000405 0.000120 True 3.38\n", + "28 10 0.000377 0.000116 True 3.24\n", + "29 10 0.000400 0.000112 True 3.56\n", + "30 100 0.000596 0.000374 True 1.59\n", + "31 100 0.000454 0.000226 True 2.01\n", + "32 100 0.000403 0.000192 True 2.10\n", + "33 100 0.000374 0.000178 True 2.10\n", + "34 100 0.000353 0.000178 True 1.98\n", + "35 100 0.000342 0.000176 True 1.95\n", + "36 100 0.000324 0.000165 True 1.96\n", + "37 100 0.000318 0.000159 True 2.00\n", + "38 100 0.000314 0.000156 True 2.01\n", + "39 100 0.000316 0.000160 True 1.98\n", + "40 1000 0.013394 0.003922 True 3.42\n", + "41 1000 0.015261 0.002053 True 7.43\n", + "42 1000 0.012111 0.002061 True 5.88\n", + "43 1000 0.012727 0.001787 True 7.12\n", + "44 1000 0.011358 0.001635 True 6.94\n", + "45 1000 0.011545 0.001832 True 6.30\n", + "46 1000 0.011723 0.001660 True 7.06\n", + "47 1000 0.011872 0.001853 True 6.41\n", + "48 1000 0.011764 0.001686 True 6.98\n", + "49 1000 0.012014 0.001991 True 6.04\n", + "50 10000 1.018952 0.099710 True 10.22\n", + "51 10000 1.020243 0.107800 True 9.46\n", + "52 10000 1.032524 0.098206 True 10.51\n", + "53 10000 1.051917 0.109932 True 9.57\n", + "54 10000 1.058021 0.097128 True 10.89\n", + "55 10000 1.010834 0.100011 True 10.11\n", + "56 10000 1.010619 0.101155 True 9.99\n", + "57 10000 1.395246 0.239068 True 5.84\n", + "58 10000 1.149318 0.101077 True 11.37\n", + "59 10000 0.989659 0.098101 True 10.09" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame(results)\n", + "\n", + "# Compare the results element-wise and aggregate the results\n", + "df[\"result_match\"] = df.apply(lambda row: np.allclose(row[\"sklearn_result\"], row[\"sklearnex_result\"]), axis=1)\n", + "df.drop(columns=[\"sklearn_result\", \"sklearnex_result\"], inplace=True)\n", + "\n", + "# Calculate the speedup\n", + "df[\"speedup\"] = df[\"sklearn_time\"] / df[\"sklearnex_time\"]\n", + "df[\"speedup\"] = df[\"speedup\"].apply(lambda x: round(x, 2))\n", + "\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Row SizeResult Match CountAverage Speedup
02103.017
15102.969
210103.103
3100101.968
41000106.358
510000109.805
\n", + "
" + ], + "text/plain": [ + " Row Size Result Match Count Average Speedup\n", + "0 2 10 3.017\n", + "1 5 10 2.969\n", + "2 10 10 3.103\n", + "3 100 10 1.968\n", + "4 1000 10 6.358\n", + "5 10000 10 9.805" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Calculate the number of times the results match and \n", + "# the average speedup for each row size\n", + "df_avg = df.groupby(\"# of row\").agg({\"result_match\": \"sum\", \"speedup\": \"mean\"}).reset_index()\n", + "df_avg.columns = [\"Row Size\", \"Result Match Count\", \"Average Speedup\"]\n", + "\n", + "df_avg" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary & Conclusion\n", + "\n", + "- The results indicate that sklearnex consistently produces the same results as scikit-learn (`result_match = True`).\n", + "- Performance improvements vary depending on input size, with a notable speedup as the number of rows increases.\n", + "- For large datasets (e.g., 10,000 rows), sklearnex achieves up to a **10x speedup** compared to scikit-learn.\n", + "- However, for smaller datasets (e.g., row size = 2), sklearnex sometimes shows **slower performance**." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/sklearnex/metrics/pairwise.py b/sklearnex/metrics/pairwise.py index 8ad789dce1..886d2cf565 100755 --- a/sklearnex/metrics/pairwise.py +++ b/sklearnex/metrics/pairwise.py @@ -16,7 +16,11 @@ from daal4py.sklearn.metrics import pairwise_distances from onedal._device_offload import support_input_format +from onedal.primitives import rbf_kernel as onedal_rbf_kernel pairwise_distances = support_input_format(freefunc=True, queue_param=False)( pairwise_distances ) + + +rbf_kernel = support_input_format(freefunc=True, queue_param=False)(onedal_rbf_kernel) diff --git a/sklearnex/metrics/tests/test_metrics.py b/sklearnex/metrics/tests/test_metrics.py index 010a6ef0e1..baad105dc6 100755 --- a/sklearnex/metrics/tests/test_metrics.py +++ b/sklearnex/metrics/tests/test_metrics.py @@ -37,3 +37,16 @@ def test_sklearnex_import_pairwise_distances(): x = np.vstack([x, x]) res = pairwise_distances(x, metric="cosine") assert_allclose(res, [[0.0, 0.0], [0.0, 0.0]], atol=1e-2) + + +def test_sklearnex_import_rbf_kernel(): + from sklearnex.metrics.pairwise import rbf_kernel + + rng = np.random.RandomState(0) + X = rng.rand(5, 3) + gamma = 0.5 + res = rbf_kernel(X, gamma=gamma) + expected_res = np.exp( + -gamma * np.sum((X[:, np.newaxis] - X[np.newaxis, :]) ** 2, axis=-1) + ) + assert_allclose(res, expected_res, atol=1e-6) From 66a4992b3152d084c1be684feacc9086dede12b4 Mon Sep 17 00:00:00 2001 From: "Tan, An Nie" Date: Thu, 13 Mar 2025 00:26:32 +0800 Subject: [PATCH 2/5] add RBFKernel class --- sklearnex/metrics/pairwise.py | 57 +++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/sklearnex/metrics/pairwise.py b/sklearnex/metrics/pairwise.py index 886d2cf565..a25e308618 100755 --- a/sklearnex/metrics/pairwise.py +++ b/sklearnex/metrics/pairwise.py @@ -14,13 +14,66 @@ # limitations under the License. # =============================================================================== +from sklearn.base import BaseEstimator +from sklearn.metrics.pairwise import rbf_kernel as _sklearn_rbf_kernel + +from daal4py.sklearn._utils import sklearn_check_version from daal4py.sklearn.metrics import pairwise_distances from onedal._device_offload import support_input_format -from onedal.primitives import rbf_kernel as onedal_rbf_kernel +from onedal.primitives import rbf_kernel as _onedal_rbf_kernel + +from .._device_offload import dispatch +from .._utils import PatchingConditionsChain pairwise_distances = support_input_format(freefunc=True, queue_param=False)( pairwise_distances ) -rbf_kernel = support_input_format(freefunc=True, queue_param=False)(onedal_rbf_kernel) +if sklearn_check_version("1.6"): + from sklearn.utils.validation import validate_data +else: + validate_data = BaseEstimator._validate_data + + +class RBFKernel: + __doc__ = _sklearn_rbf_kernel.__doc__ + + def __init__(self): + pass + + def _onedal_supported(self, method_name, *data): + patching_status = PatchingConditionsChain( + f"sklearn.metrics.pairwise.{method_name}" + ) + print(f"patching_status.get_status() = {patching_status.get_status()}") + return patching_status + + def _onedal_cpu_supported(self, method_name, *data): + return self._onedal_supported(method_name, *data) + + def _onedal_gpu_supported(self, method_name, *data): + return self._onedal_supported(method_name, *data) + + def _onedal_rbf_kernel(self, X, Y=None, gamma=None, queue=None): + print(f"queue = {queue}") + return _onedal_rbf_kernel(X, Y, gamma, queue) + + def compute(self, X, Y=None, gamma=None): + result = dispatch( + self, + "rbf_kernel", + { + "onedal": self.__class__._onedal_rbf_kernel, + "sklearn": _sklearn_rbf_kernel, + }, + X, + Y, + gamma, + ) + + return result + + +def rbf_kernel(X, Y=None, gamma=None): + return RBFKernel().compute(X, Y, gamma) From de30de7cb5e9730e4cf8261f1e7ac64d0b1d7ef7 Mon Sep 17 00:00:00 2001 From: "Tan, An Nie" Date: Thu, 13 Mar 2025 00:29:18 +0800 Subject: [PATCH 3/5] remove print --- sklearnex/metrics/pairwise.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearnex/metrics/pairwise.py b/sklearnex/metrics/pairwise.py index a25e308618..ea87c295bf 100755 --- a/sklearnex/metrics/pairwise.py +++ b/sklearnex/metrics/pairwise.py @@ -46,7 +46,6 @@ def _onedal_supported(self, method_name, *data): patching_status = PatchingConditionsChain( f"sklearn.metrics.pairwise.{method_name}" ) - print(f"patching_status.get_status() = {patching_status.get_status()}") return patching_status def _onedal_cpu_supported(self, method_name, *data): @@ -56,7 +55,6 @@ def _onedal_gpu_supported(self, method_name, *data): return self._onedal_supported(method_name, *data) def _onedal_rbf_kernel(self, X, Y=None, gamma=None, queue=None): - print(f"queue = {queue}") return _onedal_rbf_kernel(X, Y, gamma, queue) def compute(self, X, Y=None, gamma=None): From 11db635a3491f8214a7992f6c2e0e6fc268125be Mon Sep 17 00:00:00 2001 From: "Tan, An Nie" Date: Thu, 13 Mar 2025 08:43:37 +0800 Subject: [PATCH 4/5] remove unused docstring --- sklearnex/metrics/pairwise.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearnex/metrics/pairwise.py b/sklearnex/metrics/pairwise.py index ea87c295bf..b9913fc923 100755 --- a/sklearnex/metrics/pairwise.py +++ b/sklearnex/metrics/pairwise.py @@ -37,7 +37,6 @@ class RBFKernel: - __doc__ = _sklearn_rbf_kernel.__doc__ def __init__(self): pass From 698e604c392161274da9beca9a1f9184efa075b9 Mon Sep 17 00:00:00 2001 From: Tan An Nie <121005973+tanannie22@users.noreply.github.com> Date: Fri, 11 Jul 2025 15:43:21 +0800 Subject: [PATCH 5/5] Update pairwise.py --- sklearnex/metrics/pairwise.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sklearnex/metrics/pairwise.py b/sklearnex/metrics/pairwise.py index 5ba8c809fd..9821af87fc 100755 --- a/sklearnex/metrics/pairwise.py +++ b/sklearnex/metrics/pairwise.py @@ -25,9 +25,7 @@ from .._device_offload import dispatch from .._utils import PatchingConditionsChain -pairwise_distances = support_input_format(freefunc=True, queue_param=False)( - pairwise_distances -) +pairwise_distances = support_input_format(pairwise_distances) if sklearn_check_version("1.6"): @@ -74,4 +72,3 @@ def compute(self, X, Y=None, gamma=None): def rbf_kernel(X, Y=None, gamma=None): return RBFKernel().compute(X, Y, gamma) -