Skip to content

Commit 40eabb6

Browse files
committed
Implement windows
Signed-off-by: Cristian Le <[email protected]>
1 parent 6fec3b9 commit 40eabb6

File tree

2 files changed

+258
-7
lines changed

2 files changed

+258
-7
lines changed

src/scikit_build_core/repair_wheel/windows.py

Lines changed: 254 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44

55
from __future__ import annotations
66

7-
from typing import TYPE_CHECKING
7+
import dataclasses
8+
import os.path
9+
import textwrap
10+
from pathlib import Path
11+
from typing import TYPE_CHECKING, ClassVar
812

9-
from . import WheelRepairer
13+
from .._logging import logger
14+
from . import WheelRepairer, _get_buildenv_platlib
1015

1116
if TYPE_CHECKING:
1217
from ..file_api.model.codemodel import Target
@@ -18,13 +23,257 @@ def __dir__() -> list[str]:
1823
return __all__
1924

2025

26+
@dataclasses.dataclass
2127
class WindowsWheelRepairer(WheelRepairer):
2228
"""
23-
Do some windows specific magic.
29+
Patch the package and top-level python module files with ``os.add_dll_directory``.
2430
"""
2531

2632
_platform = "Windows"
2733

34+
PATCH_PY_FILE: ClassVar[str] = textwrap.dedent("""\
35+
# start scikit-build-core Windows patch
36+
def _skbuild_patch_dll_dir():
37+
import os
38+
import os.path
39+
40+
mod_dir = os.path.abspath(os.path.dirname(__file__))
41+
path_to_platlib = os.path.normpath({path_to_platlib!r})
42+
dll_paths = {dll_paths!r}
43+
for path in dll_paths:
44+
path = os.path.normpath(path)
45+
path = os.path.join(mod_dir, path_to_platlib, path)
46+
os.add_dll_directory(path)
47+
48+
_skbuild_patch_dll_dir()
49+
del _skbuild_patch_dll_dir
50+
# end scikit-build-core Windows patch
51+
""")
52+
dll_dirs: set[Path] = dataclasses.field(default_factory=set, init=False)
53+
"""All dll paths used relative to ``platlib``."""
54+
55+
def get_dll_path_from_lib(self, lib_path: Path) -> Path | None:
56+
"""Guess the dll path from lib path."""
57+
dll_path = None
58+
platlib = Path(_get_buildenv_platlib())
59+
lib_path = lib_path.relative_to(platlib)
60+
# Change the `.lib` to `.dll`
61+
if ".dll" in (suffixes := lib_path.suffixes):
62+
# In some cases like msys, they use `.dll.a`, in which case we can't use `with_suffix`
63+
# Get the right-most position of the .dll suffix
64+
dll_suffix_pos = len(suffixes) - suffixes[::-1].index(".dll") - 1
65+
# Drop all suffixes past the right-most .dll suffix
66+
suffixes = suffixes[: dll_suffix_pos + 1]
67+
dll_name = f"{lib_path.stem}{''.join(suffixes)}"
68+
else:
69+
dll_name = lib_path.with_suffix(".dll").name
70+
# Try to find the dll in the same package directory
71+
if len(lib_path.parts) > 1:
72+
pkg_dir = lib_path.parts[0]
73+
for root, _, files in os.walk(platlib / pkg_dir):
74+
if dll_name in files:
75+
dll_path = Path(root) / dll_name
76+
break
77+
else:
78+
logger.debug(
79+
"Did not find the dll file under {pkg_dir}",
80+
pkg_dir=pkg_dir,
81+
)
82+
if not dll_path:
83+
logger.debug(
84+
"Looking for {dll_name} in all platlib path.",
85+
dll_name=dll_name,
86+
)
87+
for root, _, files in os.walk(platlib):
88+
if dll_name in files:
89+
dll_path = Path(root) / dll_name
90+
break
91+
else:
92+
logger.warning(
93+
"Could not find dll file {dll_name} corresponding to {lib_path}",
94+
dll_name=dll_name,
95+
lib_path=lib_path,
96+
)
97+
return None
98+
logger.debug(
99+
"Found dll file {dll_path}",
100+
dll_path=dll_path,
101+
)
102+
return self.path_relative_site_packages(dll_path)
103+
104+
def get_library_dependencies(self, target: Target) -> list[Target]:
105+
msg = "get_library_dependencies is not generalized for Windows."
106+
raise NotImplementedError(msg)
107+
108+
def get_dependency_dll(self, target: Target) -> list[Path]:
109+
"""Get the dll due to target link dependencies."""
110+
dll_paths = []
111+
for dep in target.dependencies:
112+
dep_target = next(targ for targ in self.targets if targ.id == dep.id)
113+
if dep_target.type != "SHARED_LIBRARY":
114+
logger.debug(
115+
"Skipping dependency {dep_target} of type {type}",
116+
dep_target=dep_target.name,
117+
type=dep_target.type,
118+
)
119+
continue
120+
if not dep_target.install:
121+
logger.warning(
122+
"Dependency {dep_target} is not installed",
123+
dep_target=dep_target.name,
124+
)
125+
continue
126+
dll_artifact = next(
127+
artifact.path.name
128+
for artifact in dep_target.artifacts
129+
if artifact.path.suffix == ".dll"
130+
)
131+
for install_path in self.get_wheel_install_paths(dep_target):
132+
dep_install_path = self.install_dir / install_path
133+
if (dep_install_path / dll_artifact).exists():
134+
break
135+
else:
136+
logger.warning(
137+
"Could not find installed {dll_artifact} location in install paths: {install_path}",
138+
dll_artifact=dll_artifact,
139+
install_path=[
140+
dest.path for dest in dep_target.install.destinations
141+
],
142+
)
143+
continue
144+
dll_path = self.path_relative_site_packages(dep_install_path)
145+
dll_paths.append(dll_path)
146+
return dll_paths
147+
148+
def get_package_dll(self, target: Target) -> list[Path]:
149+
"""
150+
Get the dll due to external package linkage.
151+
152+
Have to use the guess the dll paths until the package targets are exposed.
153+
https://gitlab.kitware.com/cmake/cmake/-/issues/26755
154+
"""
155+
if not target.link:
156+
return []
157+
dll_paths = []
158+
for link_command in target.link.commandFragments:
159+
if link_command.role == "flags":
160+
if not link_command.fragment:
161+
logger.debug(
162+
"Skipping {target} link-flags: {flags}",
163+
target=target.name,
164+
flags=link_command.fragment,
165+
)
166+
continue
167+
if link_command.role != "libraries":
168+
logger.warning(
169+
"File-api link role {role} is not supported. "
170+
"Target={target}, command={command}",
171+
target=target.name,
172+
role=link_command.role,
173+
command=link_command.fragment,
174+
)
175+
continue
176+
# The remaining case should be a path
177+
try:
178+
# TODO: how to best catch if a string is a valid path?
179+
lib_path = Path(link_command.fragment)
180+
if not lib_path.is_absolute():
181+
# If the link_command is a space-separated list of libraries, this should be skipped
182+
logger.debug(
183+
"Skipping non-absolute-path library: {fragment}",
184+
fragment=link_command.fragment,
185+
)
186+
continue
187+
try:
188+
self.path_relative_site_packages(lib_path)
189+
except ValueError:
190+
logger.debug(
191+
"Skipping library outside site-package path: {lib_path}",
192+
lib_path=lib_path,
193+
)
194+
continue
195+
dll_path = self.get_dll_path_from_lib(lib_path)
196+
if not dll_path:
197+
continue
198+
dll_paths.append(dll_path.parent)
199+
except Exception as exc:
200+
logger.warning(
201+
"Could not parse link-library as a path: {fragment}\nexc = {exc}",
202+
fragment=link_command.fragment,
203+
exc=exc,
204+
)
205+
continue
206+
return dll_paths
207+
28208
def patch_target(self, target: Target) -> None:
29-
# TODO: Implement patching
30-
pass
209+
# Here we just gather all dll paths needed for each target
210+
package_dlls = self.get_package_dll(target)
211+
dependency_dlls = self.get_dependency_dll(target)
212+
if not package_dlls and not dependency_dlls:
213+
logger.warning(
214+
"No dll files found for target {target}",
215+
target=target.name,
216+
)
217+
return
218+
logger.debug(
219+
"Found dlls for target {target}:\n"
220+
"package_dlls={package_dlls}\n"
221+
"dependency_dlls={dependency_dlls}\n",
222+
target=target.name,
223+
package_dlls=package_dlls,
224+
dependency_dlls=dependency_dlls,
225+
)
226+
self.dll_dirs.update(package_dlls)
227+
self.dll_dirs.update(dependency_dlls)
228+
229+
def patch_python_file(self, file: Path) -> None:
230+
"""
231+
Patch python package or top-level module.
232+
233+
Make sure the python files have an appropriate ``os.add_dll_directory``
234+
for the scripts directory.
235+
"""
236+
assert self.dll_dirs
237+
assert all(not path.is_absolute() for path in self.dll_dirs)
238+
logger.debug(
239+
"Patching python file: {file}",
240+
file=file,
241+
)
242+
platlib = Path(self.wheel_dirs["platlib"])
243+
content = file.read_text()
244+
mod_dir = file.parent
245+
path_to_platlib = os.path.relpath(platlib, mod_dir)
246+
patch_script = self.PATCH_PY_FILE.format(
247+
path_to_platlib=path_to_platlib,
248+
dll_paths=[str(path) for path in self.dll_dirs],
249+
)
250+
# TODO: Account for the header comments, __future__.annotations, etc.
251+
with file.open("w") as f:
252+
f.write(f"{patch_script}\n" + content)
253+
254+
def repair_wheel(self) -> None:
255+
super().repair_wheel()
256+
platlib = Path(self.wheel_dirs["platlib"])
257+
if not self.dll_dirs:
258+
logger.debug(
259+
"Skipping wheel repair because no site-package dlls were found."
260+
)
261+
return
262+
logger.debug(
263+
"Patching dll directories: {dll_dirs}",
264+
dll_dirs=self.dll_dirs,
265+
)
266+
# TODO: Not handling namespace packages with this
267+
for path in platlib.iterdir():
268+
assert isinstance(path, Path)
269+
if path.is_dir():
270+
pkg_file = path / "__init__.py"
271+
if not pkg_file.exists():
272+
logger.debug(
273+
"Ignoring non-python package: {pkg_file}",
274+
pkg_file=pkg_file,
275+
)
276+
continue
277+
self.patch_python_file(pkg_file)
278+
elif path.suffix == ".py":
279+
self.patch_python_file(path)

tests/test_repair_wheel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def test_full_build(
7272
wheels = list(dist.glob("*.whl"))
7373
isolated.install(*wheels)
7474

75-
isolated.run("main")
76-
isolated.module("repair_wheel")
75+
if platform.system() != "Windows":
76+
# Requires a more specialized patch
77+
isolated.run("main")
78+
isolated.module("repair_wheel")
7779
isolated.execute(
7880
"from repair_wheel._module import hello; hello()",
7981
)

0 commit comments

Comments
 (0)