Skip to content

Commit c74ddff

Browse files
authored
Backport kernels changes for FA4 support (#383)
1 parent c9da101 commit c74ddff

4 files changed

Lines changed: 54 additions & 18 deletions

File tree

kernels/src/kernels/deps.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
)
1212

1313

14-
def validate_dependencies(dependencies: list[str], backend: str):
14+
def validate_dependencies(
15+
kernel_module_name: str, dependencies: list[str], backend: str
16+
):
1517
"""
1618
Validate a list of dependencies to ensure they are installed.
1719
@@ -31,13 +33,27 @@ def validate_dependencies(dependencies: list[str], backend: str):
3133
python_packages = backend_deps[dependency].get("python", [])
3234
else:
3335
# Dependency not found in general or backend-specific dependencies
34-
raise ValueError(f"Invalid dependency: {dependency}")
36+
raise ValueError(
37+
f"Kernel module `{kernel_module_name}` uses unsupported kernel dependency: {dependency}"
38+
)
3539

3640
# Check if each python package is installed
3741
for python_package in python_packages:
3842
# Convert package name to module name (replace - with _)
39-
module_name = python_package.replace("-", "_")
43+
pkg_name = python_package.get("pkg")
44+
# Assertion because this should not happen and is a bug.
45+
assert (
46+
pkg_name is not None
47+
), f"Invalid dependency data for `{dependency}`: missing `pkg` field."
48+
49+
module_name = python_package.get("import")
50+
if module_name is None:
51+
# These are typically packages that do not provide any Python
52+
# code, but get installed to Python's library dirctory. E.g.
53+
# OneAPI.
54+
continue
55+
4056
if importlib.util.find_spec(module_name) is None:
4157
raise ImportError(
42-
f"Kernel requires Python dependency `{python_package}`. Please install with: pip install {python_package}"
58+
f"Kernel module `{kernel_module_name}` requires Python dependency `{pkg_name}`. Please install with: pip install {pkg_name}"
4359
)

kernels/src/kernels/python_depends.json

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,33 @@
22
"general": {
33
"einops": {
44
"nix": ["einops"],
5-
"python": ["einops"]
5+
"python": [{ "pkg": "einops", "import": "einops" }]
6+
},
7+
"tvm-ffi": {
8+
"nix": ["tvm-ffi"],
9+
"python": [{ "pkg": "apache-tvm-ffi", "import": "tvm_ffi" }]
610
}
711
},
812
"backends": {
913
"cpu": {},
1014
"cuda": {
1115
"nvidia-cutlass-dsl": {
1216
"nix": ["nvidia-cutlass-dsl"],
13-
"python": ["nvidia-cutlass-dsl"]
17+
"python": [{ "pkg": "nvidia-cutlass-dsl", "import": "cutlass" }]
1418
}
1519
},
1620
"metal": {},
1721
"neuron": {
1822
"nki": {
1923
"nix": [],
20-
"python": ["nki"]
24+
"python": [{ "pkg": "nki", "import": "nki" }]
2125
}
2226
},
2327
"rocm": {},
2428
"xpu": {
2529
"onednn": {
2630
"nix": [],
27-
"python": ["onednn-devel"]
31+
"python": [{ "pkg": "onednn-devel" }]
2832
}
2933
}
3034
}

kernels/src/kernels/utils.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def _get_cache_dir() -> str | None:
2828
"""Returns the kernels cache directory."""
2929
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
3030
if cache_dir is not None:
31-
logging.warning("HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead")
31+
logging.warning(
32+
"HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
33+
)
3234
return cache_dir
3335

3436
return os.environ.get("KERNELS_CACHE", None)
@@ -136,7 +138,7 @@ def build_variants() -> list[str]:
136138

137139
def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
138140
metadata = Metadata.load_from_variant(variant_path)
139-
validate_dependencies(metadata.python_depends, backend())
141+
validate_dependencies(module_name, metadata.python_depends, backend())
140142

141143
file_path = variant_path / "__init__.py"
142144
if not file_path.exists():
@@ -203,7 +205,9 @@ def install_kernel(
203205
try:
204206
return _find_kernel_in_repo_path(repo_path, package_name, variant_locks)
205207
except FileNotFoundError:
206-
raise FileNotFoundError(f"Cannot install kernel from repo {repo_id} (revision: {revision})")
208+
raise FileNotFoundError(
209+
f"Cannot install kernel from repo {repo_id} (revision: {revision})"
210+
)
207211

208212

209213
def _find_kernel_in_repo_path(
@@ -268,7 +272,9 @@ def install_kernel_all_variants(
268272
if variant_lock is None:
269273
raise ValueError(f"No lock found for build variant: {variant}")
270274

271-
validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)
275+
validate_kernel(
276+
repo_path=repo_path, variant=variant, hash=variant_lock.hash
277+
)
272278

273279
return repo_path / "build"
274280

@@ -311,7 +317,9 @@ def get_kernel(
311317
```
312318
"""
313319
revision = select_revision_or_version(repo_id, revision=revision, version=version)
314-
package_name, variant_path = install_kernel(repo_id, revision=revision, user_agent=user_agent)
320+
package_name, variant_path = install_kernel(
321+
repo_id, revision=revision, user_agent=user_agent
322+
)
315323
return _import_from_path(package_name, variant_path)
316324

317325

@@ -344,7 +352,9 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
344352
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
345353

346354

347-
def has_kernel(repo_id: str, revision: str | None = None, version: int | str | None = None) -> bool:
355+
def has_kernel(
356+
repo_id: str, revision: str | None = None, version: int | str | None = None
357+
) -> bool:
348358
"""
349359
Check whether a kernel build exists for the current environment (Torch version and compute framework).
350360
@@ -417,7 +427,9 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
417427
)
418428

419429
try:
420-
package_name, variant_path = _find_kernel_in_repo_path(repo_path, package_name, variant_locks=None)
430+
package_name, variant_path = _find_kernel_in_repo_path(
431+
repo_path, package_name, variant_locks=None
432+
)
421433
return _import_from_path(package_name, variant_path)
422434
except FileNotFoundError:
423435
raise FileNotFoundError(
@@ -443,7 +455,9 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp
443455
if locked_sha is None:
444456
raise ValueError(f"Kernel `{repo_id}` is not locked")
445457

446-
package_name, variant_path = install_kernel(repo_id, locked_sha, local_files_only=local_files_only)
458+
package_name, variant_path = install_kernel(
459+
repo_id, locked_sha, local_files_only=local_files_only
460+
)
447461

448462
return _import_from_path(package_name, variant_path)
449463

kernels/tests/test_deps.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ def test_python_deps(dependency):
1111
if must_raise:
1212
with pytest.raises(
1313
ImportError,
14-
match=r"Kernel requires Python dependency `(einops|nvidia-cutlass-dsl)`",
14+
match=r"Kernel module `python_dep` requires Python dependency `(einops|nvidia-cutlass-dsl)`",
1515
):
1616
get_kernel("kernels-test/python-dep")
1717
else:
1818
get_kernel("kernels-test/python-dep")
1919

2020

2121
def test_illegal_dep():
22-
with pytest.raises(ValueError, match=r"Invalid dependency: kepler-22b"):
22+
with pytest.raises(
23+
ValueError, match=r"Kernel module `python_invalid_dep` uses.*kepler-22b"
24+
):
2325
get_kernel("kernels-test/python-invalid-dep")

0 commit comments

Comments
 (0)