Skip to content

Commit cfe3b2f

Browse files
Include and exclude patterns in nav.bundle.save
1 parent 712fb04 commit cfe3b2f

File tree

3 files changed

+177
-96
lines changed

3 files changed

+177
-96
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
- new: Improved TorchCompile performance for repeated compilations using TORCHINDUCTOR_CACHE_DIR environment variable
2424
- new: Global context with scoped variables - temporary context variables
2525
- new: Added new context variables `INPLACE_OPTIMIZE_WORKSPACE_CONTEXT_KEY` and `INPLACE_OPTIMIZE_MODULE_GRAPH_ID_CONTEXT_KEY`
26+
- new: `nav.bundle.save` now accepts include and exclude patterns for fine-grained file selection
2627
- change: Install the TensorRT package for architectures other than x86_64
2728
- change: Disable conversion fallback for TensorRT paths and expose control option in custom config
2829
- fix: Correctness command relative tolerance formula

model_navigator/inplace/bundle.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,27 @@ def __str__(self) -> str:
194194
"""Name of the selection strategy."""
195195
return self.__class__.__name__
196196

197+
@abc.abstractmethod
198+
def modules_selected(self) -> List[str]:
199+
"""Selects modules to bundle."""
200+
raise NotImplementedError
201+
197202

198203
class AllModulesSelection(BundleModuleSelection):
    """Selection strategy that bundles every module available in the cache."""

    def modules_selected(self) -> List[str]:
        """Return the cache entries for all modules, without any filtering."""
        selected = _all_modules_names()
        return selected
209+
201210

202211
class RegisteredModulesSelection(BundleModuleSelection):
    """Selection strategy that bundles only the registered modules from the cache."""

    def modules_selected(self) -> List[str]:
        """Return the cache entries that belong to registered modules."""
        selected = _registered_modules_names()
        return selected
217+
205218

206219
class BestRunnersSelection(BundleModuleSelection):
207220
"""Selects only best runners from registered modules for bundling."""
@@ -217,6 +230,10 @@ def __init__(self, runner_selection_strategies: Optional[List[RuntimeSearchStrat
217230
super().__init__()
218231
self.runner_selection_strategies = runner_selection_strategies or DEFAULT_RUNTIME_STRATEGIES
219232

233+
def modules_selected(self) -> List[str]:
234+
"""Selects only best runners from registered modules for bundling."""
235+
return _only_module_best_runner(self.runner_selection_strategies)
236+
220237

221238
class ModulesByNameSelection(BundleModuleSelection):
222239
"""Sometimes user may want to save only specific modules."""
@@ -229,38 +246,52 @@ def __init__(self, module_names: List[str]) -> None:
229246
"""
230247
self.module_names = module_names
231248

249+
def modules_selected(self) -> List[str]:
250+
"""Selects only selected registered modules for bundling."""
251+
return _modules_by_name(self.module_names)
252+
232253

233254
def save(
234255
bundle_path: Union[str, Path],
235256
modules: Optional[BundleModuleSelection] = None,
236257
tags: Optional[List[str]] = None,
258+
include_patterns: Optional[List[str]] = None,
259+
exclude_patterns: Optional[List[str]] = None,
237260
):
238261
"""Saves cache bundle to archive for easy storage.
239262
240263
Args:
241264
bundle_path: Where to save bundle file
242265
modules: Strategy for selecting modules. @see BundleModuleSelection and subclasses Defaults to BestRunnersSelection with MaxThroughputAndMinLatencyStrategy runners.
243266
tags: a set of tags, for better bundle identification and selection. Defaults to None.
267+
include_patterns: List of regex patterns to include.
268+
If provided, only files matching at least one pattern will be included.
269+
exclude_patterns: List of regex patterns to exclude.
270+
Files matching any of these patterns will be excluded.
244271
245272
Raises:
246273
ModelNavigatorModuleNotOptimizedError: When selected modules are not optimized yet
274+
ValueError: When include_patterns and exclude_patterns are provided at the same time
247275
"""
248276
modules = modules or BestRunnersSelection(DEFAULT_RUNTIME_STRATEGIES)
249277
cache_dir = inplace_cache_dir()
278+
file_filter = _create_file_filter(include_patterns, exclude_patterns)
250279

251280
# saving to temporary file and then moving to final location to avoid corrupted files
252281
with TemporaryDirectory() as tmp_dir:
253282
tmp_zip = Path(tmp_dir) / "bundle.nav"
254283
with zipfile.ZipFile(tmp_zip, "w") as zip_file:
255-
for entry in _selected_cache_entries(modules):
284+
for entry in modules.modules_selected():
256285
entry_path = cache_dir / entry
257286

258-
if entry_path.is_file():
287+
if entry_path.is_file() and file_filter(entry_path):
259288
zip_file.write(entry_path, entry)
260-
else:
261-
for dirpath, _, filenames in os.walk(entry_path): # Path.walk() since 3.12
262-
for filename in filenames:
263-
file_path = Path(dirpath) / filename
289+
continue
290+
291+
for dirpath, _, filenames in os.walk(entry_path): # Path.walk() since 3.12
292+
for filename in filenames:
293+
file_path = Path(dirpath) / filename
294+
if file_path.is_file() and file_filter(file_path):
264295
zip_file.write(file_path, file_path.relative_to(cache_dir))
265296

266297
# lastly adding tags to the bundle
@@ -269,6 +300,34 @@ def save(
269300
shutil.copy(tmp_zip, bundle_path)
270301

271302

303+
def _create_file_filter(include_patterns: Optional[List[str]] = None, exclude_patterns: Optional[List[str]] = None):
304+
include_patterns = include_patterns or []
305+
exclude_patterns = exclude_patterns or []
306+
307+
if not include_patterns and not exclude_patterns:
308+
return lambda _: True
309+
310+
if include_patterns and exclude_patterns:
311+
raise ValueError(
312+
"include_patterns and exclude_patterns cannot be provided at the same time. Use only one filtering method."
313+
)
314+
315+
import re
316+
317+
include_compiled = [re.compile(pattern) for pattern in include_patterns]
318+
exclude_compiled = [re.compile(pattern) for pattern in exclude_patterns]
319+
320+
def should_include(path: Path) -> bool:
321+
path = str(path)
322+
if include_compiled:
323+
return any(pattern.search(path) for pattern in include_compiled)
324+
if exclude_compiled:
325+
return not any(pattern.search(path) for pattern in exclude_compiled)
326+
return True
327+
328+
return should_include
329+
330+
272331
def _only_module_best_runner(strategies: List[RuntimeSearchStrategy]) -> List[str]:
273332
cache_dir = inplace_cache_dir()
274333

@@ -316,19 +375,6 @@ def _modules_by_name(module_names: List[str]) -> List[str]:
316375
return [name for name in _registered_modules_names() if name in module_names]
317376

318377

319-
def _selected_cache_entries(select_modules: BundleModuleSelection) -> List[str]:
320-
if isinstance(select_modules, AllModulesSelection):
321-
return _all_modules_names()
322-
if isinstance(select_modules, RegisteredModulesSelection):
323-
return _registered_modules_names()
324-
if isinstance(select_modules, BestRunnersSelection):
325-
return _only_module_best_runner(select_modules.runner_selection_strategies)
326-
if isinstance(select_modules, ModulesByNameSelection):
327-
return _modules_by_name(select_modules.module_names)
328-
329-
raise ValueError(f"Unknown module selection strategy: {select_modules}")
330-
331-
332378
def _raise_if_module_not_optimized(name, module):
333379
if not module.is_optimized:
334380
raise ModelNavigatorModuleNotOptimizedError(

0 commit comments

Comments
 (0)