Skip to content

Commit 0cf2feb

Browse files
authored
Add ci output for generated flakes (huggingface#254)
This output is like `bundle`, but only builds one variant for each framework.
1 parent a48cbd1 commit 0cf2feb

File tree

3 files changed

+73
-43
lines changed

3 files changed

+73
-43
lines changed

flake.nix

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
system: buildSet:
4747
import lib/build.nix {
4848
inherit (nixpkgs) lib;
49-
buildSets = buildSetPerSystem.${system};
5049
}
5150
) buildSetPerSystem;
5251

@@ -104,6 +103,7 @@
104103
pythonNativeCheckInputs
105104
;
106105
build = buildPerSystem.${system};
106+
buildSets = buildSetPerSystem.${system};
107107
}
108108
);
109109
}

lib/build.nix

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{
22
lib,
33

4-
# List of build sets. Each build set is a attrset of the form
5-
#
6-
# { pkgs = <nixpkgs>, torch = <torch drv> }
7-
#
8-
# The Torch derivation is built as-is. So e.g. the ABI version should
9-
# already be set.
10-
buildSets,
4+
# Every `buildSets` argument is a list of build sets. Each build set is
5+
# a attrset of the form
6+
#
7+
# { pkgs = <nixpkgs>, torch = <torch drv> }
8+
#
9+
# The Torch derivation is built as-is. So e.g. the ABI version should
10+
# already be set.
1111
}:
1212

1313
let
@@ -106,10 +106,11 @@ rec {
106106
in
107107
builtins.filter supportedBuildSet buildSets;
108108

109-
applicableBuildSets = path: filterApplicableBuildSets (readBuildConfig path) buildSets;
109+
applicableBuildSets =
110+
{ path, buildSets }: filterApplicableBuildSets (readBuildConfig path) buildSets;
110111

111112
# Build a single Torch extension.
112-
buildTorchExtension =
113+
mkTorchExtension =
113114
{
114115
buildConfig,
115116
pkgs,
@@ -172,56 +173,47 @@ rec {
172173
});
173174

174175
# Build multiple Torch extensions.
175-
buildNixTorchExtensions =
176-
{ path, rev }:
177-
let
178-
extensionForTorch =
179-
{ path, rev }:
180-
buildSet: {
181-
name = torchBuildVersion buildSet;
182-
value = buildTorchExtension buildSet { inherit path rev; };
183-
};
184-
in
185-
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) (applicableBuildSets path));
186-
187-
# Build multiple Torch extensions.
188-
buildDistTorchExtensions =
176+
mkDistTorchExtensions =
189177
{
190178
path,
191179
rev,
192180
doGetKernelCheck,
193181
bundleOnly,
182+
buildSets,
194183
}:
195184
let
196185
extensionForTorch =
197186
{ path, rev }:
198187
buildSet: {
199188
name = torchBuildVersion buildSet;
200-
value = buildTorchExtension buildSet {
189+
value = mkTorchExtension buildSet {
201190
inherit path rev doGetKernelCheck;
202191
stripRPath = true;
203192
oldLinuxCompat = true;
204193
};
205194
};
206195
applicableBuildSets' =
207-
if bundleOnly then
208-
builtins.filter (buildSet: buildSet.bundleBuild) (applicableBuildSets path)
209-
else
210-
(applicableBuildSets path);
196+
if bundleOnly then builtins.filter (buildSet: buildSet.bundleBuild) buildSets else buildSets;
211197
in
212198
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) applicableBuildSets');
213199

214-
buildTorchExtensionBundle =
200+
mkTorchExtensionBundle =
215201
{
216202
path,
217203
rev,
218204
doGetKernelCheck,
205+
buildSets,
219206
}:
220207
let
221208
# We just need to get any nixpkgs for use by the path join.
222209
pkgs = (builtins.head buildSets).pkgs;
223-
extensions = buildDistTorchExtensions {
224-
inherit path rev doGetKernelCheck;
210+
extensions = mkDistTorchExtensions {
211+
inherit
212+
buildSets
213+
path
214+
rev
215+
doGetKernelCheck
216+
;
225217
bundleOnly = true;
226218
};
227219
buildConfig = readBuildConfig path;
@@ -243,6 +235,7 @@ rec {
243235
{
244236
path,
245237
rev,
238+
buildSets,
246239
doGetKernelCheck,
247240
pythonCheckInputs,
248241
pythonNativeCheckInputs,
@@ -271,18 +264,19 @@ rec {
271264
++ (pythonCheckInputs python3.pkgs);
272265
shellHook = ''
273266
export PYTHONPATH=''${PYTHONPATH}:${
274-
buildTorchExtension buildSet { inherit path rev doGetKernelCheck; }
267+
mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }
275268
}
276269
'';
277270
};
278271
};
279272
in
280-
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) (applicableBuildSets path));
273+
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) buildSets);
281274

282-
torchDevShells =
275+
mkTorchDevShells =
283276
{
284277
path,
285278
rev,
279+
buildSets,
286280
doGetKernelCheck,
287281
pythonCheckInputs,
288282
pythonNativeCheckInputs,
@@ -309,7 +303,7 @@ rec {
309303
]
310304
++ (pythonNativeCheckInputs python3.pkgs);
311305
buildInputs = with pkgs; [ python3.pkgs.pytest ] ++ (pythonCheckInputs python3.pkgs);
312-
inputsFrom = [ (buildTorchExtension buildSet { inherit path rev doGetKernelCheck; }) ];
306+
inputsFrom = [ (mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }) ];
313307
env = lib.optionalAttrs rocmSupport {
314308
PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" buildSet.torch.rocmArchs;
315309
HIP_PATH = pkgs.rocmPackages.clr;
@@ -318,5 +312,5 @@ rec {
318312
};
319313
};
320314
in
321-
builtins.listToAttrs (lib.map shellForBuildSet (applicableBuildSets path));
315+
builtins.listToAttrs (lib.map shellForBuildSet buildSets);
322316
}

lib/gen-flake-outputs.nix

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
runCommand,
88

99
path,
10+
buildSets,
1011
rev ? null,
1112
self ? null,
1213

@@ -37,6 +38,8 @@ let
3738

3839
revUnderscored = builtins.replaceStrings [ "-" ] [ "_" ] flakeRev;
3940

41+
applicableBuildSets = build.applicableBuildSets { inherit path buildSets; };
42+
4043
# For picking a default shell, etc. we want to use the following logic:
4144
#
4245
# - Prefer bundle builds over non-bundle builds.
@@ -46,7 +49,7 @@ let
4649

4750
# Enrich the build configs with generic attributes for framework
4851
# order/version. Also make bundleBuild attr explicit.
49-
buildSets = map (
52+
addSortOrder = map (
5053
set:
5154
let
5255
inherit (set) buildConfig;
@@ -57,12 +60,23 @@ let
5760

5861
buildConfig // {
5962
bundleBuild = buildConfig.bundleBuild or false;
63+
framework =
64+
if buildConfig ? cudaVersion then
65+
"cuda"
66+
else if buildConfig ? rocmVersion then
67+
"rocm"
68+
else if buildConfig ? xpuVersion then
69+
"xpu"
70+
else if system == "aarch64-darwin" then
71+
"metal"
72+
else
73+
throw "Cannot determine framework for build set";
6074
frameworkOrder = if buildConfig ? cudaVersion then 0 else 1;
6175
frameworkVersion =
6276
buildConfig.cudaVersion or buildConfig.rocmVersion or buildConfig.xpuVersion or "0.0";
6377
};
6478
}
65-
) (build.applicableBuildSets path);
79+
);
6680
configCompare =
6781
setA: setB:
6882
let
@@ -77,25 +91,27 @@ let
7791
builtins.compareVersions a.torchVersion b.torchVersion > 0
7892
else
7993
builtins.compareVersions a.frameworkVersion b.frameworkVersion < 0;
80-
buildSetsSorted = lib.sort configCompare buildSets;
94+
buildSetsSorted = lib.sort configCompare (addSortOrder applicableBuildSets);
8195
bestBuildSet =
8296
if buildSetsSorted == [ ] then
8397
throw "No build variant is compatible with this system"
8498
else
8599
builtins.head buildSetsSorted;
86100
shellTorch = buildName bestBuildSet.buildConfig;
101+
headOrEmpty = l: if l == [ ] then [ ] else [ (builtins.head l) ];
87102
in
88103
{
89104
devShells = rec {
90105
default = devShells.${shellTorch};
91106
test = testShells.${shellTorch};
92-
devShells = build.torchDevShells {
107+
devShells = build.mkTorchDevShells {
93108
inherit
94109
path
95110
doGetKernelCheck
96111
pythonCheckInputs
97112
pythonNativeCheckInputs
98113
;
114+
buildSets = applicableBuildSets;
99115
rev = revUnderscored;
100116
};
101117
testShells = build.torchExtensionShells {
@@ -105,13 +121,15 @@ in
105121
pythonCheckInputs
106122
pythonNativeCheckInputs
107123
;
124+
buildSets = applicableBuildSets;
108125
rev = revUnderscored;
109126
};
110127
};
111128
packages =
112129
let
113-
bundle = build.buildTorchExtensionBundle {
130+
bundle = build.mkTorchExtensionBundle {
114131
inherit path doGetKernelCheck;
132+
buildSets = applicableBuildSets;
115133
rev = revUnderscored;
116134
};
117135
in
@@ -140,6 +158,23 @@ in
140158
chmod -R +w build
141159
'';
142160

161+
ci =
162+
let
163+
setsWithFramework =
164+
framework: builtins.filter (set: set.buildConfig.framework == framework) buildSetsSorted;
165+
# It is too costly to build all variants in CI, so we just build one per framework.
166+
onePerFramework =
167+
(headOrEmpty (setsWithFramework "cuda"))
168+
++ (headOrEmpty (setsWithFramework "metal"))
169+
++ (headOrEmpty (setsWithFramework "rocm"))
170+
++ (headOrEmpty (setsWithFramework "xpu"));
171+
in
172+
build.mkTorchExtensionBundle {
173+
inherit path doGetKernelCheck;
174+
buildSets = onePerFramework;
175+
rev = revUnderscored;
176+
};
177+
143178
kernels =
144179
bestBuildSet.pkgs.python3.withPackages (
145180
ps: with ps; [
@@ -151,10 +186,11 @@ in
151186
meta.mainProgram = "kernels";
152187
};
153188

154-
redistributable = build.buildDistTorchExtensions {
189+
redistributable = build.mkDistTorchExtensions {
155190
inherit path doGetKernelCheck;
156191
bundleOnly = false;
157192
rev = revUnderscored;
193+
buildSets = applicableBuildSets;
158194
};
159195
};
160196
}

0 commit comments

Comments
 (0)