|
16 | 16 | supportedCudaCapabilities = builtins.fromJSON ( |
17 | 17 | builtins.readFile ../build2cmake/src/cuda_supported_archs.json |
18 | 18 | ); |
19 | | - inherit (import ./torch-version-utils.nix { inherit lib; }) isCuda isMetal isRocm; |
| 19 | + inherit (import ./torch-version-utils.nix { inherit lib; }) |
| 20 | + isCuda |
| 21 | + isMetal |
| 22 | + isRocm |
| 23 | + isXpu |
| 24 | + ; |
20 | 25 | in |
21 | 26 | rec { |
22 | 27 | resolveDeps = import ./deps.nix { inherit lib; }; |
@@ -45,11 +50,13 @@ rec { |
45 | 50 | cuda = false; |
46 | 51 | metal = false; |
47 | 52 | rocm = false; |
| 53 | + xpu = false; |
48 | 54 | }; |
49 | 55 | in |
50 | 56 | lib.foldl (backends: kernel: backends // { ${kernelBackend kernel} = true; }) init kernels; |
51 | 57 |
|
52 | 58 | readBuildConfig = path: validateBuildConfig (readToml (path + "/build.toml")); |
| 59 | + tracedReadBuildConfig = path: readBuildConfig path; |
53 | 60 |
|
54 | 61 | srcFilter = |
55 | 62 | src: name: type: |
|
75 | 82 | (isCuda buildSet.buildConfig && backends'.cuda) |
76 | 83 | || (isRocm buildSet.buildConfig && backends'.rocm) |
77 | 84 | || (isMetal buildSet.buildConfig && backends'.metal) |
| 85 | + || (isXpu buildSet.buildConfig && backends'.xpu) |
78 | 86 | || (buildConfig.general.universal or false); |
79 | 87 | cudaVersionSupported = |
80 | 88 | !(isCuda buildSet.buildConfig) |
@@ -121,6 +129,8 @@ rec { |
121 | 129 | stdenv = |
122 | 130 | if pkgs.stdenv.hostPlatform.isDarwin then |
123 | 131 | pkgs.stdenv |
| 132 | + else if lib.any (k: k.backend == "xpu") (lib.attrValues buildConfig.kernel) then |
| 133 | + pkgs.stdenv |
124 | 134 | else if oldLinuxCompat then |
125 | 135 | pkgs.stdenvGlibc_2_27 |
126 | 136 | else |
@@ -234,7 +244,13 @@ rec { |
234 | 244 | let |
235 | 245 | pkgs = buildSet.pkgs; |
236 | 246 | rocmSupport = pkgs.config.rocmSupport or false; |
237 | | - stdenv = if rocmSupport then pkgs.stdenv else pkgs.cudaPackages.backendStdenv; |
| 247 | + stdenv = |
| 248 | + if rocmSupport then |
| 249 | + pkgs.stdenv |
| 250 | + else if isXpu buildSet.buildConfig then |
| 251 | + pkgs.stdenv |
| 252 | + else |
| 253 | + pkgs.cudaPackages.backendStdenv; |
238 | 254 | mkShell = pkgs.mkShell.override { inherit stdenv; }; |
239 | 255 | in |
240 | 256 | { |
@@ -274,7 +290,14 @@ rec { |
274 | 290 | let |
275 | 291 | pkgs = buildSet.pkgs; |
276 | 292 | rocmSupport = pkgs.config.rocmSupport or false; |
277 | | - stdenv = if rocmSupport then pkgs.stdenv else pkgs.cudaPackages.backendStdenv; |
| 293 | + xpuSupport = pkgs.config.xpuSupport or false; |
| 294 | + stdenv = |
| 295 | + if rocmSupport then |
| 296 | + pkgs.stdenv |
| 297 | + else if xpuSupport then |
| 298 | + pkgs.stdenv |
| 299 | + else |
| 300 | + pkgs.cudaPackages.backendStdenv; |
278 | 301 | mkShell = pkgs.mkShell.override { inherit stdenv; }; |
279 | 302 | in |
280 | 303 | { |
|
0 commit comments