Skip to content

Commit 1957032

Browse files
committed
add xpu support in nix
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 08fcbf3 commit 1957032

File tree

8 files changed

+125
-10
lines changed

8 files changed

+125
-10
lines changed

build2cmake/src/templates/xpu/preamble.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ endif()
3333

3434
# Check for Intel XPU support in PyTorch
3535
run_python(XPU_AVAILABLE
36-
"import torch; print('true' if hasattr(torch, 'xpu') and torch.xpu.is_available() else 'false')"
36+
"import torch; print('true' if hasattr(torch, 'xpu') else 'false')"
3737
"Failed to check XPU availability")
3838

3939
if(NOT XPU_AVAILABLE STREQUAL "true")

lib/build-sets.nix

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ let
1515
isCuda
1616
isMetal
1717
isRocm
18+
isXpu
1819
;
1920

2021
# All build configurations supported by Torch.
@@ -39,6 +40,12 @@ let
3940
in
4041
lib.unique (builtins.map (torchVersion: torchVersion.rocmVersion) withRocm);
4142

43+
xpuVersions =
44+
let
45+
withXpu = builtins.filter (torchVersion: torchVersion ? xpuVersion) torchVersions;
46+
in
47+
lib.unique (builtins.map (torchVersion: torchVersion.xpuVersion) withXpu);
48+
4249
flattenVersion = version: lib.replaceStrings [ "." ] [ "_" ] (lib.versions.pad 2 version);
4350

4451
# An overlay that overides CUDA to the given version.
@@ -50,12 +57,16 @@ let
5057
rocmPackages = super."rocmPackages_${flattenVersion rocmVersion}";
5158
};
5259

60+
overlayForXpuVersion = xpuVersion: self: super: {
61+
xpuPackages = super."xpuPackages_${flattenVersion xpuVersion}";
62+
};
5363
# Construct the nixpkgs package set for the given versions.
5464
pkgsForVersions =
5565
buildConfig@{
5666
cudaVersion ? null,
5767
metal ? false,
5868
rocmVersion ? null,
69+
xpuVersion ? null,
5970
torchVersion,
6071
cxx11Abi,
6172
system,
@@ -69,11 +80,18 @@ let
6980
pkgsByRocmVer.${rocmVersion}
7081
else if isMetal buildConfig then
7182
pkgsForMetal
83+
else if isXpu buildConfig then
84+
pkgsByXpuVer.${xpuVersion}
7285
else
7386
throw "No compute framework set in Torch version";
74-
torch = pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
75-
inherit cxx11Abi;
76-
};
87+
torch = pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".overrideAttrs (old: {
88+
passthru = (old.passthru or { }) // {
89+
xpuPackages = pkgs.xpuPackages or { };
90+
};
91+
cxx11Abi = cxx11Abi;
92+
xpuSupport =
93+
(pkgs ? xpuPackages) && ((builtins.length (builtins.attrNames pkgs.xpuPackages or { })) > 0);
94+
});
7795
in
7896
{
7997
inherit
@@ -83,6 +101,26 @@ let
83101
bundleBuild
84102
;
85103
};
104+
pkgsForXpuVersions =
105+
xpuVersions:
106+
builtins.listToAttrs (
107+
map (xpuVersion: {
108+
name = xpuVersion;
109+
value = import nixpkgs {
110+
inherit system;
111+
config = {
112+
allowUnfree = true;
113+
xpuSupport = true;
114+
};
115+
overlays = [
116+
hf-nix
117+
overlay
118+
(overlayForXpuVersion xpuVersion)
119+
];
120+
};
121+
}) xpuVersions
122+
);
123+
pkgsByXpuVer = pkgsForXpuVersions xpuVersions;
86124

87125
pkgsForMetal = import nixpkgs {
88126
inherit system;

lib/build-variants.nix

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ let
55
isCuda
66
isMetal
77
isRocm
8+
isXpu
89
;
910
in
1011
rec {
@@ -16,8 +17,10 @@ rec {
1617
"metal"
1718
else if buildConfig ? "rocmVersion" then
1819
"rocm"
20+
else if buildConfig ? xpuVersion then
21+
"xpu"
1922
else
20-
throw "Could not find compute framework: no CUDA or ROCm version specified and Metal is not enabled";
23+
throw "Could not find compute framework: no CUDA, ROCm, XPU version specified and Metal is not enabled";
2124

2225
# Build variants included in bundle builds.
2326
buildVariants =
@@ -31,6 +34,8 @@ rec {
3134
"rocm${flattenVersion (lib.versions.majorMinor version.rocmVersion)}"
3235
else if isMetal version then
3336
"metal"
37+
else if isXpu version then
38+
"xpu${flattenVersion (lib.versions.majorMinor version.xpuVersion)}"
3439
else
3540
throw "No compute framework set in Torch version";
3641
buildName =

lib/build-version.nix

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,24 @@ let
1212
cudaVersion = torch: "cu${flattenVersion torch.cudaPackages.cudaMajorMinorVersion}";
1313
rocmVersion =
1414
torch: "rocm${flattenVersion (lib.versions.majorMinor torch.rocmPackages.rocm.version)}";
15-
gpuVersion = torch: (if torch.cudaSupport then cudaVersion else rocmVersion) torch;
1615
torchVersion = torch: flattenVersion torch.version;
16+
xpuVersion =
17+
torch:
18+
"xpu${flattenVersion (lib.versions.majorMinor torch.passthru.xpuPackages.intel-oneapi-dpcpp-cpp.version)}";
19+
gpuVersion =
20+
torch:
21+
if torch.cudaSupport then
22+
cudaVersion torch
23+
else if (torch ? rocmPackages) && (torch.rocmSupport or false) then
24+
rocmVersion torch
25+
else if (torch ? passthru) && (torch.passthru ? xpuPackages) && (torch.xpuSupport or false) then
26+
xpuVersion torch
27+
else
28+
null;
1729
in
1830
if pkgs.stdenv.hostPlatform.isDarwin then
1931
"torch${torchVersion torch}-metal-${targetPlatform}"
20-
else
32+
else if gpuVersion torch != null then
2133
"torch${torchVersion torch}-${abi torch}-${gpuVersion torch}-${targetPlatform}"
34+
else
35+
throw "No supported GPU framework (CUDA, ROCm, XPU, Metal) detected for build-version.nix"

lib/build.nix

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ let
1616
supportedCudaCapabilities = builtins.fromJSON (
1717
builtins.readFile ../build2cmake/src/cuda_supported_archs.json
1818
);
19-
inherit (import ./torch-version-utils.nix { inherit lib; }) isCuda isMetal isRocm;
19+
inherit (import ./torch-version-utils.nix { inherit lib; })
20+
isCuda
21+
isMetal
22+
isRocm
23+
isXpu
24+
;
2025
in
2126
rec {
2227
resolveDeps = import ./deps.nix { inherit lib; };
@@ -45,11 +50,13 @@ rec {
4550
cuda = false;
4651
metal = false;
4752
rocm = false;
53+
xpu = false;
4854
};
4955
in
5056
lib.foldl (backends: kernel: backends // { ${kernelBackend kernel} = true; }) init kernels;
5157

5258
readBuildConfig = path: validateBuildConfig (readToml (path + "/build.toml"));
59+
tracedReadBuildConfig = path: readBuildConfig path;
5360

5461
srcFilter =
5562
src: name: type:
@@ -75,6 +82,7 @@ rec {
7582
(isCuda buildSet.buildConfig && backends'.cuda)
7683
|| (isRocm buildSet.buildConfig && backends'.rocm)
7784
|| (isMetal buildSet.buildConfig && backends'.metal)
85+
|| (isXpu buildSet.buildConfig && backends'.xpu)
7886
|| (buildConfig.general.universal or false);
7987
cudaVersionSupported =
8088
!(isCuda buildSet.buildConfig)
@@ -121,6 +129,8 @@ rec {
121129
stdenv =
122130
if pkgs.stdenv.hostPlatform.isDarwin then
123131
pkgs.stdenv
132+
else if lib.any (k: k.backend == "xpu") (lib.attrValues buildConfig.kernel) then
133+
pkgs.stdenv
124134
else if oldLinuxCompat then
125135
pkgs.stdenvGlibc_2_27
126136
else
@@ -234,7 +244,13 @@ rec {
234244
let
235245
pkgs = buildSet.pkgs;
236246
rocmSupport = pkgs.config.rocmSupport or false;
237-
stdenv = if rocmSupport then pkgs.stdenv else pkgs.cudaPackages.backendStdenv;
247+
stdenv =
248+
if rocmSupport then
249+
pkgs.stdenv
250+
else if isXpu buildSet.buildConfig then
251+
pkgs.stdenv
252+
else
253+
pkgs.cudaPackages.backendStdenv;
238254
mkShell = pkgs.mkShell.override { inherit stdenv; };
239255
in
240256
{
@@ -274,7 +290,14 @@ rec {
274290
let
275291
pkgs = buildSet.pkgs;
276292
rocmSupport = pkgs.config.rocmSupport or false;
277-
stdenv = if rocmSupport then pkgs.stdenv else pkgs.cudaPackages.backendStdenv;
293+
xpuSupport = pkgs.config.xpuSupport or false;
294+
stdenv =
295+
if rocmSupport then
296+
pkgs.stdenv
297+
else if xpuSupport then
298+
pkgs.stdenv
299+
else
300+
pkgs.cudaPackages.backendStdenv;
278301
mkShell = pkgs.mkShell.override { inherit stdenv; };
279302
in
280303
{

lib/torch-extension/default.nix

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
config,
1515
cudaSupport ? config.cudaSupport,
1616
rocmSupport ? config.rocmSupport,
17+
xpuSupport ? config.xpuSupport,
1718

1819
lib,
1920
stdenv,
@@ -70,6 +71,8 @@ stdenv.mkDerivation (prevAttrs: {
7071
"cuda"
7172
else if rocmSupport then
7273
"rocm"
74+
else if xpuSupport then
75+
"xpu"
7376
else
7477
"metal"
7578
} --ops-id ${rev} build.toml
@@ -98,6 +101,13 @@ stdenv.mkDerivation (prevAttrs: {
98101
++ lib.optionals rocmSupport [
99102
clr
100103
]
104+
++ lib.optionals xpuSupport (
105+
with torch.passthru.xpuPackages;
106+
[
107+
ocloc
108+
oneapi-torch-dev
109+
]
110+
)
101111
++ lib.optionals stdenv.hostPlatform.isDarwin [
102112
rewrite-nix-paths-macho
103113
];
@@ -120,6 +130,12 @@ stdenv.mkDerivation (prevAttrs: {
120130
]
121131
)
122132
++ lib.optionals rocmSupport (with rocmPackages; [ hipsparselt ])
133+
++ lib.optionals xpuSupport (
134+
with torch.passthru.xpuPackages;
135+
[
136+
oneapi-torch-dev
137+
]
138+
)
123139
++ lib.optionals stdenv.hostPlatform.isDarwin [
124140
apple-sdk_15
125141
]
@@ -136,6 +152,10 @@ stdenv.mkDerivation (prevAttrs: {
136152
}
137153
// lib.optionalAttrs rocmSupport {
138154
PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" torch.rocmArchs;
155+
}
156+
// lib.optionalAttrs xpuSupport {
157+
MKLROOT = torch.passthru.xpuPackages.oneapi-torch-dev;
158+
SYCL_ROOT = torch.passthru.xpuPackages.oneapi-torch-dev;
139159
};
140160

141161
# If we use the default setup, CMAKE_CUDA_HOST_COMPILER gets set to nixpkgs g++.

lib/torch-version-utils.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
isCuda = version: version ? cudaVersion;
1111
isMetal = version: version.metal or false;
1212
isRocm = version: version ? rocmVersion;
13+
isXpu = version: version ? xpuVersion;
1314
}

versions.nix

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,20 @@
4040
bundleBuild = false;
4141
}
4242

43+
{
44+
torchVersion = "2.7";
45+
xpuVersion = "2025.0.2";
46+
cxx11Abi = true;
47+
systems = [ "x86_64-linux" ];
48+
bundleBuild = true;
49+
}
50+
{
51+
torchVersion = "2.8";
52+
xpuVersion = "2025.1.3";
53+
cxx11Abi = true;
54+
systems = [ "x86_64-linux" ];
55+
bundleBuild = true;
56+
}
4357
{
4458
torchVersion = "2.7";
4559
cxx11Abi = true;

0 commit comments

Comments
 (0)