
Commit 2f51d2a

committed
add tutorial for integrating a custom op using SYCL
1 parent 4dee820 commit 2f51d2a

3 files changed: +281 −0 lines changed
Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
.. _cpp-custom-ops-tutorial-sycl:

Custom SYCL Operators
=====================

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to integrate custom operators written in SYCL with PyTorch

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.8 or later
       * Basic understanding of SYCL programming

.. note::

    ``SYCL`` serves as the backend programming language for Intel GPUs (device label ``xpu``). For configuration details, see
    `Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html>`_. The Intel Compiler, which comes bundled with Intel Deep Learning Essentials, handles ``SYCL`` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.
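
    For example, with a typical system-wide oneAPI installation, the compiler environment can be activated as sketched below (the install path is an assumption; adjust it to match your setup):

    .. code-block:: bash

        # Activate the Intel compiler / oneAPI environment (assumed default
        # system-wide install location; adjust if installed elsewhere).
        source /opt/intel/oneapi/setvars.sh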

PyTorch offers a large library of operators that work on Tensors (for example, ``torch.add`` and ``torch.sum``).
However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
blessed path to authoring a custom operator written in SYCL. Tutorials for C++ and CUDA operators are available in :ref:`cpp-custom-ops-tutorial`.

Use the following structure to create a custom SYCL operator:

.. code-block:: text

    sycl_example/
    ├── setup.py
    ├── sycl_extension
    │   ├── __init__.py
    │   ├── muladd.sycl
    │   └── ops.py
    └── test_sycl_extension.py

Setting up the Build System
---------------------------

If you need to compile **SYCL** code (for example, ``.sycl`` files), use `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension>`_.
The setup process is very similar to C++/CUDA, except the compilation arguments need to be adjusted for SYCL.

Using ``SyclExtension`` is as simple as writing the following ``setup.py``:

.. code-block:: python

    import os
    import torch
    import glob
    from setuptools import find_packages, setup
    from torch.utils.cpp_extension import SyclExtension, BuildExtension

    library_name = "sycl_extension"
    py_limited_api = True
    extra_compile_args = {
        "cxx": ["-O3",
                "-fdiagnostics-color=always",
                "-DPy_LIMITED_API=0x03090000"],
        "sycl": ["-O3"],
    }

    assert torch.xpu.is_available(), "XPU is not available, please check your environment"
    # Source files collection
    this_dir = os.path.dirname(os.path.curdir)
    extensions_dir = os.path.join(this_dir, library_name)
    sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))
    # Construct extension
    ext_modules = [
        SyclExtension(
            f"{library_name}._C",
            sources,
            extra_compile_args=extra_compile_args,
            py_limited_api=py_limited_api,
        )
    ]
    setup(
        name=library_name,
        packages=find_packages(),
        ext_modules=ext_modules,
        install_requires=["torch"],
        description="Simple example of a PyTorch SYCL extension",
        cmdclass={"build_ext": BuildExtension},
        options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
    )

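With ``setup.py`` in place, the extension can be built and installed from the ``sycl_example/`` project root. A minimal sketch of a typical workflow (the exact command is an assumption; any standard ``pip``/``setuptools`` build works) is:

.. code-block:: bash

    # Build the SYCL sources and install the package in editable mode,
    # with the Intel compiler environment already activated.
    pip install --no-build-isolation -e .
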
Defining the custom op and adding backend implementations
----------------------------------------------------------

First, let's write a SYCL kernel that computes ``mymuladd`` (``a * b + c``).

In order to use this operator from PyTorch's Python frontend, we need to register it
as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
bind the operator to Python.

The SYCL (XPU) implementation is then registered for that operator in a separate
``TORCH_LIBRARY_IMPL`` block. Both go into ``sycl_extension/muladd.sycl``:

.. code-block:: cpp

    // Copyright (c) 2025 Intel Corporation

    #include <c10/xpu/XPUStream.h>
    #include <sycl/sycl.hpp>
    #include <ATen/Operators.h>
    #include <torch/all.h>
    #include <torch/library.h>

    namespace sycl_extension {
    // MulAdd kernel: result = a * b + c
    static void muladd_kernel(
        int numel, const float* a, const float* b, float c, float* result,
        const sycl::nd_item<1>& item) {
      int idx = item.get_global_id(0);
      if (idx < numel) {
        result[idx] = a[idx] * b[idx] + c;
      }
    }

    class MulAddKernelFunctor {
     public:
      MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
          : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
      void operator()(const sycl::nd_item<1>& item) const {
        muladd_kernel(numel, a, b, c, result, item);
      }

     private:
      int numel;
      const float* a;
      const float* b;
      float c;
      float* result;
    };

    at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
      TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
      TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
      TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
      TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
      TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

      at::Tensor a_contig = a.contiguous();
      at::Tensor b_contig = b.contiguous();
      at::Tensor result = at::empty_like(a_contig);

      const float* a_ptr = a_contig.data_ptr<float>();
      const float* b_ptr = b_contig.data_ptr<float>();
      float* res_ptr = result.data_ptr<float>();
      int numel = a_contig.numel();

      sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
      constexpr int threads = 256;
      int blocks = (numel + threads - 1) / threads;

      queue.submit([&](sycl::handler& cgh) {
        cgh.parallel_for<MulAddKernelFunctor>(
            sycl::nd_range<1>(blocks * threads, threads),
            MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
        );
      });

      return result;
    }

    // Defines the operator
    TORCH_LIBRARY(sycl_extension, m) {
      m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
    }

    // ==================================================
    // Register the SYCL implementation with the library
    // ==================================================
    TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
      m.impl("mymuladd", &mymuladd_xpu);
    }

    } // namespace sycl_extension

Create a Python Interface
-------------------------

Create a Python interface for our operator in the ``sycl_extension/ops.py`` file:

.. code-block:: python

    import torch
    from torch import Tensor

    __all__ = ["mymuladd"]

    def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
        """Performs a * b + c in an efficient fused kernel"""
        return torch.ops.sycl_extension.mymuladd.default(a, b, c)

Initialize Package
------------------

Create a ``sycl_extension/__init__.py`` file to make the package importable:

.. code-block:: python

    import ctypes
    from pathlib import Path

    import torch

    current_dir = Path(__file__).parent.parent
    build_dir = current_dir / "build"
    so_files = list(build_dir.glob("**/*.so"))

    assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"

    with torch._ops.dl_open_guard():
        loaded_lib = ctypes.CDLL(so_files[0])

    from . import ops

    __all__ = [
        "loaded_lib",
        "ops",
    ]

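With the package built and importable, the operator can be called like any other PyTorch op. A minimal sketch, assuming the extension was built as described above and an Intel GPU is available:

.. code-block:: python

    import torch
    import sycl_extension

    # Two float tensors on the Intel GPU ("xpu" device).
    a = torch.randn(3, device="xpu")
    b = torch.randn(3, device="xpu")

    # Fused multiply-add executed by the custom SYCL kernel.
    out = sycl_extension.ops.mymuladd(a, b, 1.0)
    print(torch.allclose(out, a * b + 1.0))  # expected: True
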
Testing the SYCL extension operator
-----------------------------------

Use a simple test in ``test_sycl_extension.py`` to verify that the operator works correctly:

.. code-block:: python

    import torch
    from torch.testing._internal.common_utils import TestCase
    import unittest
    import sycl_extension

    def reference_muladd(a, b, c):
        return a * b + c

    class TestMyMulAdd(TestCase):
        def sample_inputs(self, device, *, requires_grad=False):
            def make_tensor(*size):
                return torch.randn(size, device=device, requires_grad=requires_grad)

            def make_nondiff_tensor(*size):
                return torch.randn(size, device=device, requires_grad=False)

            return [
                [make_tensor(3), make_tensor(3), 1],
                [make_tensor(20), make_tensor(20), 3.14],
                [make_tensor(20), make_nondiff_tensor(20), -123],
                [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
            ]

        def _test_correctness(self, device):
            samples = self.sample_inputs(device)
            for args in samples:
                result = sycl_extension.ops.mymuladd(*args)
                expected = reference_muladd(*args)
                torch.testing.assert_close(result, expected)

        @unittest.skipIf(not torch.xpu.is_available(), "requires Intel GPU")
        def test_correctness_xpu(self):
            self._test_correctness("xpu")

    if __name__ == "__main__":
        unittest.main()

This test checks the correctness of the custom operator by comparing its output against a reference implementation.

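The test can be run directly with Python's ``unittest`` runner from the project root (a sketch of a typical invocation):

.. code-block:: bash

    # Requires the built extension and an Intel GPU visible to PyTorch.
    python test_sycl_extension.py
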
Conclusion
----------

In this tutorial, we demonstrated how to implement and compile custom SYCL operators for PyTorch, using the inference-only ``muladd`` operation as an example. For adding backward support or enabling ``torch.compile`` compatibility, refer to :ref:`cpp-custom-ops-tutorial`.

advanced_source/custom_ops_landing_page.rst

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,10 @@ Integrating custom C++ and/or CUDA code with PyTorch

Please see :ref:`cpp-custom-ops-tutorial`.

.. note::

    ``SYCL`` serves as the backend programming language for Intel GPUs. To integrate custom SYCL code, refer to :ref:`cpp-custom-ops-tutorial-sycl`.

You may wish to author a custom operator from C++ (as opposed to Python) if:

- you have custom C++ and/or CUDA code.

index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,7 @@ Additional Resources
977977
advanced/custom_ops_landing_page
978978
advanced/python_custom_ops
979979
advanced/cpp_custom_ops
980+
advanced/cpp_custom_ops_sycl
980981
intermediate/custom_function_double_backward_tutorial
981982
intermediate/custom_function_conv_bn_tutorial
982983
advanced/cpp_extension

0 commit comments

Comments
 (0)