diff --git a/examples/notebooks/benchmark_rbf_kernel.ipynb b/examples/notebooks/benchmark_rbf_kernel.ipynb new file mode 100644 index 0000000000..c12c58a8ab --- /dev/null +++ b/examples/notebooks/benchmark_rbf_kernel.ipynb @@ -0,0 +1,853 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Benchmarking RBF Kernel Performance\n", + "\n", + "This notebook compares the execution time of the RBF kernel function in scikit-learn and sklearnex. The goal is to evaluate the performance improvements using the oneDAL-optimized implementation.\n", + "\n", + "### Methodology\n", + "- We generate random matrices of different row sizes (2, 5, 10, 100, 1000, 10000) with 3 columns.\n", + "- We measure the execution time of the RBF kernel function in scikit-learn and sklearnex.\n", + "- We compare the results for accuracy and compute the speedup factor.\n", + "\n", + "### Setup & Environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "### System Information ###\n", + "OS: Windows 10 (10.0.22631)\n", + "Windows Edition: Windows 11 Enterprise\n", + "Processor: Intel64 Family 6 Model 154 Stepping 3, GenuineIntel\n", + "CPU Cores: 12 (Physical), 16 (Logical)\n", + "RAM: 63.64 GB\n", + "Graphics Card: Intel(R) Iris(R) Xe Graphics\n", + "\n", + "### Library Versions ###\n", + "Python Version: 3.10.11\n", + "scikit-learn Version: 1.6.1\n", + "NumPy Version: 1.26.4\n", + "Pandas Version: 2.1.3\n" + ] + } + ], + "source": [ + "import platform\n", + "import psutil\n", + "import sklearn\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "print(\"### System Information ###\")\n", + "print(f\"OS: {platform.system()} {platform.release()} ({platform.version()})\")\n", + "print(\"Windows Edition: Windows 11 Enterprise\")\n", + "print(f\"Processor: {platform.processor()}\")\n", + "print(f\"CPU Cores: {psutil.cpu_count(logical=False)} (Physical), {psutil.cpu_count(logical=True)} (Logical)\")\n", + "print(f\"RAM: {round(psutil.virtual_memory().total / (1024**3), 2)} GB\")\n", + "print(\"Graphics Card: Intel(R) Iris(R) Xe Graphics\")\n", + "\n", + "print(\"\\n### Library Versions ###\")\n", + "print(f\"Python Version: {platform.python_version()}\")\n", + "print(f\"scikit-learn Version: {sklearn.__version__}\")\n", + "print(f\"NumPy Version: {np.__version__}\")\n", + "print(f\"Pandas Version: {pd.__version__}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from timeit import default_timer as timer\n", + "from sklearn.metrics.pairwise import rbf_kernel as sklearn_rbf_kernel\n", + "from sklearnex.metrics.pairwise import rbf_kernel as sklearnex_rbf_kernel\n", + "import numpy as np\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "# a list of different sizes of data to be tested\n", + "row_sizes = [2, 5, 10, 100, 1000, 10000]\n", + "col_size = 3\n", + "\n", + "results = []\n", + "\n", + "for row in row_sizes:\n", + " X = np.random.rand(row, col_size)\n", + " Y = np.random.rand(row, col_size)\n", + "\n", + " for i in range(10):\n", + " start = timer()\n", + " sklearn_result = sklearn_rbf_kernel(X, Y)\n", + " sklearn_time = timer() - start\n", + "\n", + " start = timer()\n", + " sklearnex_result = sklearnex_rbf_kernel(X, Y)\n", + " sklearnex_time = timer() - start\n", + "\n", + " results.append({\n", + " \"# of row\": row,\n", + " \"sklearn_result\": sklearn_result,\n", + " \"sklearnex_result\": sklearnex_result,\n", + " \"sklearn_time\": sklearn_time,\n", + " \"sklearnex_time\": sklearnex_time\n", + " })" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Results\n", + "\n", + "The results are stored in a dataframe for easier interpretation and analysis.\n", + "\n", + "- The `result_match` column verifies if both implementations produce identical results.\n", + "- The `speedup` column shows how much faster sklearnex is compared to scikit-learn." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
# of rowsklearn_timesklearnex_timeresult_matchspeedup
020.0005750.002887True0.20
120.0007020.000182True3.86
220.0005320.000176True3.03
320.0005100.000155True3.29
420.0005060.000150True3.38
520.0004770.000146True3.27
620.0004700.000144True3.28
720.0004710.000143True3.29
820.0004680.000143True3.27
920.0004870.000148True3.30
1050.0004770.000170True2.80
1150.0005490.000184True2.98
1250.0003280.000102True3.22
1350.0003070.000107True2.88
1450.0003100.000140True2.22
1550.0003530.000111True3.17
1650.0003450.000112True3.08
1750.0003410.000108True3.17
1850.0003580.000124True2.90
1950.0003880.000119True3.27
20100.0003980.000152True2.61
21100.0004300.000142True3.02
22100.0004170.000130True3.21
23100.0004100.000160True2.57
24100.0003980.000130True3.06
25100.0004020.000134True3.01
26100.0004290.000127True3.37
27100.0004050.000120True3.38
28100.0003770.000116True3.24
29100.0004000.000112True3.56
301000.0005960.000374True1.59
311000.0004540.000226True2.01
321000.0004030.000192True2.10
331000.0003740.000178True2.10
341000.0003530.000178True1.98
351000.0003420.000176True1.95
361000.0003240.000165True1.96
371000.0003180.000159True2.00
381000.0003140.000156True2.01
391000.0003160.000160True1.98
4010000.0133940.003922True3.42
4110000.0152610.002053True7.43
4210000.0121110.002061True5.88
4310000.0127270.001787True7.12
4410000.0113580.001635True6.94
4510000.0115450.001832True6.30
4610000.0117230.001660True7.06
4710000.0118720.001853True6.41
4810000.0117640.001686True6.98
4910000.0120140.001991True6.04
50100001.0189520.099710True10.22
51100001.0202430.107800True9.46
52100001.0325240.098206True10.51
53100001.0519170.109932True9.57
54100001.0580210.097128True10.89
55100001.0108340.100011True10.11
56100001.0106190.101155True9.99
57100001.3952460.239068True5.84
58100001.1493180.101077True11.37
59100000.9896590.098101True10.09
\n", + "
" + ], + "text/plain": [ + " # of row sklearn_time sklearnex_time result_match speedup\n", + "0 2 0.000575 0.002887 True 0.20\n", + "1 2 0.000702 0.000182 True 3.86\n", + "2 2 0.000532 0.000176 True 3.03\n", + "3 2 0.000510 0.000155 True 3.29\n", + "4 2 0.000506 0.000150 True 3.38\n", + "5 2 0.000477 0.000146 True 3.27\n", + "6 2 0.000470 0.000144 True 3.28\n", + "7 2 0.000471 0.000143 True 3.29\n", + "8 2 0.000468 0.000143 True 3.27\n", + "9 2 0.000487 0.000148 True 3.30\n", + "10 5 0.000477 0.000170 True 2.80\n", + "11 5 0.000549 0.000184 True 2.98\n", + "12 5 0.000328 0.000102 True 3.22\n", + "13 5 0.000307 0.000107 True 2.88\n", + "14 5 0.000310 0.000140 True 2.22\n", + "15 5 0.000353 0.000111 True 3.17\n", + "16 5 0.000345 0.000112 True 3.08\n", + "17 5 0.000341 0.000108 True 3.17\n", + "18 5 0.000358 0.000124 True 2.90\n", + "19 5 0.000388 0.000119 True 3.27\n", + "20 10 0.000398 0.000152 True 2.61\n", + "21 10 0.000430 0.000142 True 3.02\n", + "22 10 0.000417 0.000130 True 3.21\n", + "23 10 0.000410 0.000160 True 2.57\n", + "24 10 0.000398 0.000130 True 3.06\n", + "25 10 0.000402 0.000134 True 3.01\n", + "26 10 0.000429 0.000127 True 3.37\n", + "27 10 0.000405 0.000120 True 3.38\n", + "28 10 0.000377 0.000116 True 3.24\n", + "29 10 0.000400 0.000112 True 3.56\n", + "30 100 0.000596 0.000374 True 1.59\n", + "31 100 0.000454 0.000226 True 2.01\n", + "32 100 0.000403 0.000192 True 2.10\n", + "33 100 0.000374 0.000178 True 2.10\n", + "34 100 0.000353 0.000178 True 1.98\n", + "35 100 0.000342 0.000176 True 1.95\n", + "36 100 0.000324 0.000165 True 1.96\n", + "37 100 0.000318 0.000159 True 2.00\n", + "38 100 0.000314 0.000156 True 2.01\n", + "39 100 0.000316 0.000160 True 1.98\n", + "40 1000 0.013394 0.003922 True 3.42\n", + "41 1000 0.015261 0.002053 True 7.43\n", + "42 1000 0.012111 0.002061 True 5.88\n", + "43 1000 0.012727 0.001787 True 7.12\n", + "44 1000 0.011358 0.001635 True 6.94\n", + "45 1000 0.011545 0.001832 True 6.30\n", + "46 1000 0.011723 0.001660 True 7.06\n", + "47 1000 0.011872 0.001853 True 6.41\n", + "48 1000 0.011764 0.001686 True 6.98\n", + "49 1000 0.012014 0.001991 True 6.04\n", + "50 10000 1.018952 0.099710 True 10.22\n", + "51 10000 1.020243 0.107800 True 9.46\n", + "52 10000 1.032524 0.098206 True 10.51\n", + "53 10000 1.051917 0.109932 True 9.57\n", + "54 10000 1.058021 0.097128 True 10.89\n", + "55 10000 1.010834 0.100011 True 10.11\n", + "56 10000 1.010619 0.101155 True 9.99\n", + "57 10000 1.395246 0.239068 True 5.84\n", + "58 10000 1.149318 0.101077 True 11.37\n", + "59 10000 0.989659 0.098101 True 10.09" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame(results)\n", + "\n", + "# Compare the results element-wise and aggregate the results\n", + "df[\"result_match\"] = df.apply(lambda row: np.allclose(row[\"sklearn_result\"], row[\"sklearnex_result\"]), axis=1)\n", + "df.drop(columns=[\"sklearn_result\", \"sklearnex_result\"], inplace=True)\n", + "\n", + "# Calculate the speedup\n", + "df[\"speedup\"] = df[\"sklearn_time\"] / df[\"sklearnex_time\"]\n", + "df[\"speedup\"] = df[\"speedup\"].apply(lambda x: round(x, 2))\n", + "\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Row SizeResult Match CountAverage Speedup
02103.017
15102.969
210103.103
3100101.968
41000106.358
510000109.805
\n", + "
" + ], + "text/plain": [ + " Row Size Result Match Count Average Speedup\n", + "0 2 10 3.017\n", + "1 5 10 2.969\n", + "2 10 10 3.103\n", + "3 100 10 1.968\n", + "4 1000 10 6.358\n", + "5 10000 10 9.805" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Calculate the number of times the results match and \n", + "# the average speedup for each row size\n", + "df_avg = df.groupby(\"# of row\").agg({\"result_match\": \"sum\", \"speedup\": \"mean\"}).reset_index()\n", + "df_avg.columns = [\"Row Size\", \"Result Match Count\", \"Average Speedup\"]\n", + "\n", + "df_avg" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary & Conclusion\n", + "\n", + "- The results indicate that sklearnex consistently produces the same results as scikit-learn (`result_match = True`).\n", + "- Performance improvements vary depending on input size, with a notable speedup as the number of rows increases.\n", + "- For large datasets (e.g., 10,000 rows), sklearnex achieves up to a **10x speedup** compared to scikit-learn.\n", + "- However, for smaller datasets (e.g., row size = 2), sklearnex sometimes shows **slower performance**." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/sklearnex/metrics/pairwise.py b/sklearnex/metrics/pairwise.py index ffcc136e1d..9821af87fc 100755 --- a/sklearnex/metrics/pairwise.py +++ b/sklearnex/metrics/pairwise.py @@ -14,7 +14,61 @@ # limitations under the License. # =============================================================================== +from sklearn.base import BaseEstimator +from sklearn.metrics.pairwise import rbf_kernel as _sklearn_rbf_kernel + +from daal4py.sklearn._utils import sklearn_check_version from daal4py.sklearn.metrics import pairwise_distances from onedal._device_offload import support_input_format +from onedal.primitives import rbf_kernel as _onedal_rbf_kernel + +from .._device_offload import dispatch +from .._utils import PatchingConditionsChain pairwise_distances = support_input_format(pairwise_distances) + + +if sklearn_check_version("1.6"): + from sklearn.utils.validation import validate_data +else: + validate_data = BaseEstimator._validate_data + + +class RBFKernel: + + def __init__(self): + pass + + def _onedal_supported(self, method_name, *data): + patching_status = PatchingConditionsChain( + f"sklearn.metrics.pairwise.{method_name}" + ) + return patching_status + + def _onedal_cpu_supported(self, method_name, *data): + return self._onedal_supported(method_name, *data) + + def _onedal_gpu_supported(self, method_name, *data): + return self._onedal_supported(method_name, *data) + + def _onedal_rbf_kernel(self, X, Y=None, gamma=None, queue=None): + return _onedal_rbf_kernel(X, Y, gamma, queue) + + def compute(self, X, Y=None, gamma=None): + result = dispatch( + self, + "rbf_kernel", + { + "onedal": self.__class__._onedal_rbf_kernel, + "sklearn": _sklearn_rbf_kernel, + }, + X, + Y, + gamma, + ) + + return result + + +def rbf_kernel(X, Y=None, gamma=None): + return RBFKernel().compute(X, Y, gamma) diff --git a/sklearnex/metrics/tests/test_metrics.py b/sklearnex/metrics/tests/test_metrics.py index 010a6ef0e1..baad105dc6 100755 --- a/sklearnex/metrics/tests/test_metrics.py +++ b/sklearnex/metrics/tests/test_metrics.py @@ -37,3 +37,16 @@ def test_sklearnex_import_pairwise_distances(): x = np.vstack([x, x]) res = pairwise_distances(x, metric="cosine") assert_allclose(res, [[0.0, 0.0], [0.0, 0.0]], atol=1e-2) + + +def test_sklearnex_import_rbf_kernel(): + from sklearnex.metrics.pairwise import rbf_kernel + + rng = np.random.RandomState(0) + X = rng.rand(5, 3) + gamma = 0.5 + res = rbf_kernel(X, gamma=gamma) + expected_res = np.exp( + -gamma * np.sum((X[:, np.newaxis] - X[np.newaxis, :]) ** 2, axis=-1) + ) + assert_allclose(res, expected_res, atol=1e-6)