diff --git a/examples/notebooks/benchmark_rbf_kernel.ipynb b/examples/notebooks/benchmark_rbf_kernel.ipynb
new file mode 100644
index 0000000000..c12c58a8ab
--- /dev/null
+++ b/examples/notebooks/benchmark_rbf_kernel.ipynb
@@ -0,0 +1,853 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Benchmarking RBF Kernel Performance\n",
+ "\n",
+ "This notebook compares the execution time of the RBF kernel function in scikit-learn and sklearnex. The goal is to evaluate the performance improvements using the oneDAL-optimized implementation.\n",
+ "\n",
+ "### Methodology\n",
+ "- We generate random matrices of different row sizes (2, 5, 10, 100, 1000, 10000) with 3 columns.\n",
+ "- We measure the execution time of the RBF kernel function in scikit-learn and sklearnex.\n",
+ "- We compare the results for accuracy and compute the speedup factor.\n",
+ "\n",
+ "### Setup & Environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "### System Information ###\n",
+ "OS: Windows 10 (10.0.22631)\n",
+ "Windows Edition: Windows 11 Enterprise\n",
+ "Processor: Intel64 Family 6 Model 154 Stepping 3, GenuineIntel\n",
+ "CPU Cores: 12 (Physical), 16 (Logical)\n",
+ "RAM: 63.64 GB\n",
+ "Graphics Card: Intel(R) Iris(R) Xe Graphics\n",
+ "\n",
+ "### Library Versions ###\n",
+ "Python Version: 3.10.11\n",
+ "scikit-learn Version: 1.6.1\n",
+ "NumPy Version: 1.26.4\n",
+ "Pandas Version: 2.1.3\n"
+ ]
+ }
+ ],
+ "source": [
+ "import platform\n",
+ "import psutil\n",
+ "import sklearn\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "print(\"### System Information ###\")\n",
+ "print(f\"OS: {platform.system()} {platform.release()} ({platform.version()})\")\n",
+ "print(\"Windows Edition: Windows 11 Enterprise\")\n",
+ "print(f\"Processor: {platform.processor()}\")\n",
+ "print(f\"CPU Cores: {psutil.cpu_count(logical=False)} (Physical), {psutil.cpu_count(logical=True)} (Logical)\")\n",
+ "print(f\"RAM: {round(psutil.virtual_memory().total / (1024**3), 2)} GB\")\n",
+ "print(\"Graphics Card: Intel(R) Iris(R) Xe Graphics\")\n",
+ "\n",
+ "print(\"\\n### Library Versions ###\")\n",
+ "print(f\"Python Version: {platform.python_version()}\")\n",
+ "print(f\"scikit-learn Version: {sklearn.__version__}\")\n",
+ "print(f\"NumPy Version: {np.__version__}\")\n",
+ "print(f\"Pandas Version: {pd.__version__}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from timeit import default_timer as timer\n",
+ "from sklearn.metrics.pairwise import rbf_kernel as sklearn_rbf_kernel\n",
+ "from sklearnex.metrics.pairwise import rbf_kernel as sklearnex_rbf_kernel\n",
+ "import numpy as np\n",
+ "\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "# a list of different sizes of data to be tested\n",
+ "row_sizes = [2, 5, 10, 100, 1000, 10000]\n",
+ "col_size = 3\n",
+ "\n",
+ "results = []\n",
+ "\n",
+ "for row in row_sizes:\n",
+ " X = np.random.rand(row, col_size)\n",
+ " Y = np.random.rand(row, col_size)\n",
+ "\n",
+ " for i in range(10):\n",
+ " start = timer()\n",
+ " sklearn_result = sklearn_rbf_kernel(X, Y)\n",
+ " sklearn_time = timer() - start\n",
+ "\n",
+ " start = timer()\n",
+ " sklearnex_result = sklearnex_rbf_kernel(X, Y)\n",
+ " sklearnex_time = timer() - start\n",
+ "\n",
+ " results.append({\n",
+ " \"# of row\": row,\n",
+ " \"sklearn_result\": sklearn_result,\n",
+ " \"sklearnex_result\": sklearnex_result,\n",
+ " \"sklearn_time\": sklearn_time,\n",
+ " \"sklearnex_time\": sklearnex_time\n",
+ " })"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Results\n",
+ "\n",
+ "The results are stored in a dataframe for easier interpretation and analysis.\n",
+ "\n",
+ "- The `result_match` column verifies if both implementations produce identical results.\n",
+ "- The `speedup` column shows how much faster sklearnex is compared to scikit-learn."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " # of row | \n",
+ " sklearn_time | \n",
+ " sklearnex_time | \n",
+ " result_match | \n",
+ " speedup | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 2 | \n",
+ " 0.000575 | \n",
+ " 0.002887 | \n",
+ " True | \n",
+ " 0.20 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 0.000702 | \n",
+ " 0.000182 | \n",
+ " True | \n",
+ " 3.86 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 0.000532 | \n",
+ " 0.000176 | \n",
+ " True | \n",
+ " 3.03 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 2 | \n",
+ " 0.000510 | \n",
+ " 0.000155 | \n",
+ " True | \n",
+ " 3.29 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 2 | \n",
+ " 0.000506 | \n",
+ " 0.000150 | \n",
+ " True | \n",
+ " 3.38 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 0.000477 | \n",
+ " 0.000146 | \n",
+ " True | \n",
+ " 3.27 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 2 | \n",
+ " 0.000470 | \n",
+ " 0.000144 | \n",
+ " True | \n",
+ " 3.28 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 2 | \n",
+ " 0.000471 | \n",
+ " 0.000143 | \n",
+ " True | \n",
+ " 3.29 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 2 | \n",
+ " 0.000468 | \n",
+ " 0.000143 | \n",
+ " True | \n",
+ " 3.27 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 2 | \n",
+ " 0.000487 | \n",
+ " 0.000148 | \n",
+ " True | \n",
+ " 3.30 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 5 | \n",
+ " 0.000477 | \n",
+ " 0.000170 | \n",
+ " True | \n",
+ " 2.80 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 5 | \n",
+ " 0.000549 | \n",
+ " 0.000184 | \n",
+ " True | \n",
+ " 2.98 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 5 | \n",
+ " 0.000328 | \n",
+ " 0.000102 | \n",
+ " True | \n",
+ " 3.22 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 5 | \n",
+ " 0.000307 | \n",
+ " 0.000107 | \n",
+ " True | \n",
+ " 2.88 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 5 | \n",
+ " 0.000310 | \n",
+ " 0.000140 | \n",
+ " True | \n",
+ " 2.22 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 5 | \n",
+ " 0.000353 | \n",
+ " 0.000111 | \n",
+ " True | \n",
+ " 3.17 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 5 | \n",
+ " 0.000345 | \n",
+ " 0.000112 | \n",
+ " True | \n",
+ " 3.08 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 5 | \n",
+ " 0.000341 | \n",
+ " 0.000108 | \n",
+ " True | \n",
+ " 3.17 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 5 | \n",
+ " 0.000358 | \n",
+ " 0.000124 | \n",
+ " True | \n",
+ " 2.90 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 5 | \n",
+ " 0.000388 | \n",
+ " 0.000119 | \n",
+ " True | \n",
+ " 3.27 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 10 | \n",
+ " 0.000398 | \n",
+ " 0.000152 | \n",
+ " True | \n",
+ " 2.61 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 10 | \n",
+ " 0.000430 | \n",
+ " 0.000142 | \n",
+ " True | \n",
+ " 3.02 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 10 | \n",
+ " 0.000417 | \n",
+ " 0.000130 | \n",
+ " True | \n",
+ " 3.21 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 10 | \n",
+ " 0.000410 | \n",
+ " 0.000160 | \n",
+ " True | \n",
+ " 2.57 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 10 | \n",
+ " 0.000398 | \n",
+ " 0.000130 | \n",
+ " True | \n",
+ " 3.06 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 10 | \n",
+ " 0.000402 | \n",
+ " 0.000134 | \n",
+ " True | \n",
+ " 3.01 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 10 | \n",
+ " 0.000429 | \n",
+ " 0.000127 | \n",
+ " True | \n",
+ " 3.37 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 10 | \n",
+ " 0.000405 | \n",
+ " 0.000120 | \n",
+ " True | \n",
+ " 3.38 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 10 | \n",
+ " 0.000377 | \n",
+ " 0.000116 | \n",
+ " True | \n",
+ " 3.24 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 10 | \n",
+ " 0.000400 | \n",
+ " 0.000112 | \n",
+ " True | \n",
+ " 3.56 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 100 | \n",
+ " 0.000596 | \n",
+ " 0.000374 | \n",
+ " True | \n",
+ " 1.59 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 100 | \n",
+ " 0.000454 | \n",
+ " 0.000226 | \n",
+ " True | \n",
+ " 2.01 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 100 | \n",
+ " 0.000403 | \n",
+ " 0.000192 | \n",
+ " True | \n",
+ " 2.10 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 100 | \n",
+ " 0.000374 | \n",
+ " 0.000178 | \n",
+ " True | \n",
+ " 2.10 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 100 | \n",
+ " 0.000353 | \n",
+ " 0.000178 | \n",
+ " True | \n",
+ " 1.98 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 100 | \n",
+ " 0.000342 | \n",
+ " 0.000176 | \n",
+ " True | \n",
+ " 1.95 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 100 | \n",
+ " 0.000324 | \n",
+ " 0.000165 | \n",
+ " True | \n",
+ " 1.96 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 100 | \n",
+ " 0.000318 | \n",
+ " 0.000159 | \n",
+ " True | \n",
+ " 2.00 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 100 | \n",
+ " 0.000314 | \n",
+ " 0.000156 | \n",
+ " True | \n",
+ " 2.01 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 100 | \n",
+ " 0.000316 | \n",
+ " 0.000160 | \n",
+ " True | \n",
+ " 1.98 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 1000 | \n",
+ " 0.013394 | \n",
+ " 0.003922 | \n",
+ " True | \n",
+ " 3.42 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 1000 | \n",
+ " 0.015261 | \n",
+ " 0.002053 | \n",
+ " True | \n",
+ " 7.43 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 1000 | \n",
+ " 0.012111 | \n",
+ " 0.002061 | \n",
+ " True | \n",
+ " 5.88 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 1000 | \n",
+ " 0.012727 | \n",
+ " 0.001787 | \n",
+ " True | \n",
+ " 7.12 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 1000 | \n",
+ " 0.011358 | \n",
+ " 0.001635 | \n",
+ " True | \n",
+ " 6.94 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 1000 | \n",
+ " 0.011545 | \n",
+ " 0.001832 | \n",
+ " True | \n",
+ " 6.30 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 1000 | \n",
+ " 0.011723 | \n",
+ " 0.001660 | \n",
+ " True | \n",
+ " 7.06 | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " 1000 | \n",
+ " 0.011872 | \n",
+ " 0.001853 | \n",
+ " True | \n",
+ " 6.41 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 1000 | \n",
+ " 0.011764 | \n",
+ " 0.001686 | \n",
+ " True | \n",
+ " 6.98 | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " 1000 | \n",
+ " 0.012014 | \n",
+ " 0.001991 | \n",
+ " True | \n",
+ " 6.04 | \n",
+ "
\n",
+ " \n",
+ " 50 | \n",
+ " 10000 | \n",
+ " 1.018952 | \n",
+ " 0.099710 | \n",
+ " True | \n",
+ " 10.22 | \n",
+ "
\n",
+ " \n",
+ " 51 | \n",
+ " 10000 | \n",
+ " 1.020243 | \n",
+ " 0.107800 | \n",
+ " True | \n",
+ " 9.46 | \n",
+ "
\n",
+ " \n",
+ " 52 | \n",
+ " 10000 | \n",
+ " 1.032524 | \n",
+ " 0.098206 | \n",
+ " True | \n",
+ " 10.51 | \n",
+ "
\n",
+ " \n",
+ " 53 | \n",
+ " 10000 | \n",
+ " 1.051917 | \n",
+ " 0.109932 | \n",
+ " True | \n",
+ " 9.57 | \n",
+ "
\n",
+ " \n",
+ " 54 | \n",
+ " 10000 | \n",
+ " 1.058021 | \n",
+ " 0.097128 | \n",
+ " True | \n",
+ " 10.89 | \n",
+ "
\n",
+ " \n",
+ " 55 | \n",
+ " 10000 | \n",
+ " 1.010834 | \n",
+ " 0.100011 | \n",
+ " True | \n",
+ " 10.11 | \n",
+ "
\n",
+ " \n",
+ " 56 | \n",
+ " 10000 | \n",
+ " 1.010619 | \n",
+ " 0.101155 | \n",
+ " True | \n",
+ " 9.99 | \n",
+ "
\n",
+ " \n",
+ " 57 | \n",
+ " 10000 | \n",
+ " 1.395246 | \n",
+ " 0.239068 | \n",
+ " True | \n",
+ " 5.84 | \n",
+ "
\n",
+ " \n",
+ " 58 | \n",
+ " 10000 | \n",
+ " 1.149318 | \n",
+ " 0.101077 | \n",
+ " True | \n",
+ " 11.37 | \n",
+ "
\n",
+ " \n",
+ " 59 | \n",
+ " 10000 | \n",
+ " 0.989659 | \n",
+ " 0.098101 | \n",
+ " True | \n",
+ " 10.09 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " # of row sklearn_time sklearnex_time result_match speedup\n",
+ "0 2 0.000575 0.002887 True 0.20\n",
+ "1 2 0.000702 0.000182 True 3.86\n",
+ "2 2 0.000532 0.000176 True 3.03\n",
+ "3 2 0.000510 0.000155 True 3.29\n",
+ "4 2 0.000506 0.000150 True 3.38\n",
+ "5 2 0.000477 0.000146 True 3.27\n",
+ "6 2 0.000470 0.000144 True 3.28\n",
+ "7 2 0.000471 0.000143 True 3.29\n",
+ "8 2 0.000468 0.000143 True 3.27\n",
+ "9 2 0.000487 0.000148 True 3.30\n",
+ "10 5 0.000477 0.000170 True 2.80\n",
+ "11 5 0.000549 0.000184 True 2.98\n",
+ "12 5 0.000328 0.000102 True 3.22\n",
+ "13 5 0.000307 0.000107 True 2.88\n",
+ "14 5 0.000310 0.000140 True 2.22\n",
+ "15 5 0.000353 0.000111 True 3.17\n",
+ "16 5 0.000345 0.000112 True 3.08\n",
+ "17 5 0.000341 0.000108 True 3.17\n",
+ "18 5 0.000358 0.000124 True 2.90\n",
+ "19 5 0.000388 0.000119 True 3.27\n",
+ "20 10 0.000398 0.000152 True 2.61\n",
+ "21 10 0.000430 0.000142 True 3.02\n",
+ "22 10 0.000417 0.000130 True 3.21\n",
+ "23 10 0.000410 0.000160 True 2.57\n",
+ "24 10 0.000398 0.000130 True 3.06\n",
+ "25 10 0.000402 0.000134 True 3.01\n",
+ "26 10 0.000429 0.000127 True 3.37\n",
+ "27 10 0.000405 0.000120 True 3.38\n",
+ "28 10 0.000377 0.000116 True 3.24\n",
+ "29 10 0.000400 0.000112 True 3.56\n",
+ "30 100 0.000596 0.000374 True 1.59\n",
+ "31 100 0.000454 0.000226 True 2.01\n",
+ "32 100 0.000403 0.000192 True 2.10\n",
+ "33 100 0.000374 0.000178 True 2.10\n",
+ "34 100 0.000353 0.000178 True 1.98\n",
+ "35 100 0.000342 0.000176 True 1.95\n",
+ "36 100 0.000324 0.000165 True 1.96\n",
+ "37 100 0.000318 0.000159 True 2.00\n",
+ "38 100 0.000314 0.000156 True 2.01\n",
+ "39 100 0.000316 0.000160 True 1.98\n",
+ "40 1000 0.013394 0.003922 True 3.42\n",
+ "41 1000 0.015261 0.002053 True 7.43\n",
+ "42 1000 0.012111 0.002061 True 5.88\n",
+ "43 1000 0.012727 0.001787 True 7.12\n",
+ "44 1000 0.011358 0.001635 True 6.94\n",
+ "45 1000 0.011545 0.001832 True 6.30\n",
+ "46 1000 0.011723 0.001660 True 7.06\n",
+ "47 1000 0.011872 0.001853 True 6.41\n",
+ "48 1000 0.011764 0.001686 True 6.98\n",
+ "49 1000 0.012014 0.001991 True 6.04\n",
+ "50 10000 1.018952 0.099710 True 10.22\n",
+ "51 10000 1.020243 0.107800 True 9.46\n",
+ "52 10000 1.032524 0.098206 True 10.51\n",
+ "53 10000 1.051917 0.109932 True 9.57\n",
+ "54 10000 1.058021 0.097128 True 10.89\n",
+ "55 10000 1.010834 0.100011 True 10.11\n",
+ "56 10000 1.010619 0.101155 True 9.99\n",
+ "57 10000 1.395246 0.239068 True 5.84\n",
+ "58 10000 1.149318 0.101077 True 11.37\n",
+ "59 10000 0.989659 0.098101 True 10.09"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "df = pd.DataFrame(results)\n",
+ "\n",
+ "# Compare the results element-wise and aggregate the results\n",
+ "df[\"result_match\"] = df.apply(lambda row: np.allclose(row[\"sklearn_result\"], row[\"sklearnex_result\"]), axis=1)\n",
+ "df.drop(columns=[\"sklearn_result\", \"sklearnex_result\"], inplace=True)\n",
+ "\n",
+ "# Calculate the speedup\n",
+ "df[\"speedup\"] = df[\"sklearn_time\"] / df[\"sklearnex_time\"]\n",
+ "df[\"speedup\"] = df[\"speedup\"].apply(lambda x: round(x, 2))\n",
+ "\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Row Size | \n",
+ " Result Match Count | \n",
+ " Average Speedup | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 2 | \n",
+ " 10 | \n",
+ " 3.017 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 5 | \n",
+ " 10 | \n",
+ " 2.969 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 10 | \n",
+ " 10 | \n",
+ " 3.103 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 100 | \n",
+ " 10 | \n",
+ " 1.968 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1000 | \n",
+ " 10 | \n",
+ " 6.358 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 10000 | \n",
+ " 10 | \n",
+ " 9.805 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Row Size Result Match Count Average Speedup\n",
+ "0 2 10 3.017\n",
+ "1 5 10 2.969\n",
+ "2 10 10 3.103\n",
+ "3 100 10 1.968\n",
+ "4 1000 10 6.358\n",
+ "5 10000 10 9.805"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Calculate the number of times the results match and \n",
+ "# the average speedup for each row size\n",
+ "df_avg = df.groupby(\"# of row\").agg({\"result_match\": \"sum\", \"speedup\": \"mean\"}).reset_index()\n",
+ "df_avg.columns = [\"Row Size\", \"Result Match Count\", \"Average Speedup\"]\n",
+ "\n",
+ "df_avg"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary & Conclusion\n",
+ "\n",
+ "- The results indicate that sklearnex consistently produces the same results as scikit-learn (`result_match = True`).\n",
+ "- Performance improvements vary depending on input size, with a notable speedup as the number of rows increases.\n",
+ "- For large datasets (e.g., 10,000 rows), sklearnex achieves up to a **10x speedup** compared to scikit-learn.\n",
+ "- However, for smaller datasets (e.g., row size = 2), sklearnex sometimes shows **slower performance**."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv310",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/sklearnex/metrics/pairwise.py b/sklearnex/metrics/pairwise.py
index ffcc136e1d..9821af87fc 100755
--- a/sklearnex/metrics/pairwise.py
+++ b/sklearnex/metrics/pairwise.py
@@ -14,7 +14,61 @@
# limitations under the License.
# ===============================================================================
+from sklearn.base import BaseEstimator
+from sklearn.metrics.pairwise import rbf_kernel as _sklearn_rbf_kernel
+
+from daal4py.sklearn._utils import sklearn_check_version
from daal4py.sklearn.metrics import pairwise_distances
from onedal._device_offload import support_input_format
+from onedal.primitives import rbf_kernel as _onedal_rbf_kernel
+
+from .._device_offload import dispatch
+from .._utils import PatchingConditionsChain
pairwise_distances = support_input_format(pairwise_distances)
+
+
+if sklearn_check_version("1.6"):
+ from sklearn.utils.validation import validate_data
+else:
+ validate_data = BaseEstimator._validate_data
+
+
+class RBFKernel:
+
+ def __init__(self):
+ pass
+
+ def _onedal_supported(self, method_name, *data):
+ patching_status = PatchingConditionsChain(
+ f"sklearn.metrics.pairwise.{method_name}"
+ )
+ return patching_status
+
+ def _onedal_cpu_supported(self, method_name, *data):
+ return self._onedal_supported(method_name, *data)
+
+ def _onedal_gpu_supported(self, method_name, *data):
+ return self._onedal_supported(method_name, *data)
+
+ def _onedal_rbf_kernel(self, X, Y=None, gamma=None, queue=None):
+ return _onedal_rbf_kernel(X, Y, gamma, queue)
+
+ def compute(self, X, Y=None, gamma=None):
+ result = dispatch(
+ self,
+ "rbf_kernel",
+ {
+ "onedal": self.__class__._onedal_rbf_kernel,
+ "sklearn": _sklearn_rbf_kernel,
+ },
+ X,
+ Y,
+ gamma,
+ )
+
+ return result
+
+
+def rbf_kernel(X, Y=None, gamma=None):
+ return RBFKernel().compute(X, Y, gamma)
diff --git a/sklearnex/metrics/tests/test_metrics.py b/sklearnex/metrics/tests/test_metrics.py
index 010a6ef0e1..baad105dc6 100755
--- a/sklearnex/metrics/tests/test_metrics.py
+++ b/sklearnex/metrics/tests/test_metrics.py
@@ -37,3 +37,16 @@ def test_sklearnex_import_pairwise_distances():
x = np.vstack([x, x])
res = pairwise_distances(x, metric="cosine")
assert_allclose(res, [[0.0, 0.0], [0.0, 0.0]], atol=1e-2)
+
+
+def test_sklearnex_import_rbf_kernel():
+ from sklearnex.metrics.pairwise import rbf_kernel
+
+ rng = np.random.RandomState(0)
+ X = rng.rand(5, 3)
+ gamma = 0.5
+ res = rbf_kernel(X, gamma=gamma)
+ expected_res = np.exp(
+ -gamma * np.sum((X[:, np.newaxis] - X[np.newaxis, :]) ** 2, axis=-1)
+ )
+ assert_allclose(res, expected_res, atol=1e-6)