Skip to content

Commit 4627db8

Browse files
Remove version comparison from nvfuserex to decrease overhead (Lightning-AI#1840)
Removes `LooseVersion` comparison from `FusionDefinitionWrapper.__call__`. Co-authored-by: Ivan Yashchuk <[email protected]>
1 parent 8615c62 commit 4627db8

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

.azure/notebook-runs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ jobs:
1515
- job: jupyter
1616
strategy:
1717
matrix:
18-
"ubuntu22.04 | cuda 12.1 | torch 2.4.0":
19-
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.4.0-dev"
18+
"ubuntu22.04 | cuda 12.1 | torch 2.5.1":
19+
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.5.1-dev"
2020
CUDA_VERSION_MM: "121"
2121
"ubuntu22.04 | cuda 12.1 | torch-nightly":
2222
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_main-dev"

thunder/executors/nvfuserex_impl.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ class FusionDefinitionWrapper:
434434
enable_options: None | list[str] = None
435435
disable_options: None | list[str] = None
436436

437+
@annotate_for_profile("FusionDefinitionWrapper.__call__")
437438
def __call__(self, *args):
438439
fd = self.get_fd(self.to_descriptors(args))
439440
self.last_used = fd
@@ -451,18 +452,10 @@ def __call__(self, *args):
451452
if hasattr(fd, "_selected_device"):
452453
kwargs["device"] = fd._selected_device
453454

454-
if nvfuser_version() >= LooseVersion("0.2.23"):
455-
# nvFuser expects empty list instead of None values.
456-
kwargs["_enable_options"] = self.enable_options if self.enable_options is not None else []
457-
kwargs["_disable_options"] = self.disable_options if self.disable_options is not None else []
458-
459-
elif self.enable_options or self.disable_options:
460-
warnings.warn(
461-
f"nv_enable_options/nv_disable_options require nvFuser version 0.2.23 and above, found version {nvfuser_version()}. These options will be ignored."
462-
)
463-
464455
with annotate_for_profile(self.name):
465-
return fd.execute(args, **kwargs)
456+
return fd.execute(
457+
args, _enable_options=self.enable_options, _disable_options=self.disable_options, **kwargs
458+
)
466459

467460
def __repr__(self):
468461
return f"FusionDefinitionWrapper({self.name})"
@@ -558,9 +551,9 @@ def create_fusion_definition_wrapper(
558551
store_inputs_meta: None | bool = get_compile_option(
559552
"nv_store_fusion_inputs_meta", "Allow nvFuser to store fusion inputs metadata for repro."
560553
)
561-
enable_options: None | list[str] = get_compile_option("nv_enable_options", "List of NVFUSER_ENABLE options to set.")
562-
disable_options: None | list[str] = get_compile_option(
563-
"nv_disable_options", "List of NVFUSER_DISABLE options to set."
554+
enable_options: list[str] = get_compile_option("nv_enable_options", "List of NVFUSER_ENABLE options to set.") or []
555+
disable_options: list[str] = (
556+
get_compile_option("nv_disable_options", "List of NVFUSER_DISABLE options to set.") or []
564557
)
565558

566559
tensor_indices = []
@@ -2698,3 +2691,10 @@ def embedding(
26982691

26992692
register_supported(PrimIDs.EMBEDDING, embedding, _embedding_check)
27002693
register_supported(ltorch.embedding, embedding, _embedding_check)
2694+
2695+
2696+
# At module/class level
2697+
NVFUSER_SUPPORTS_OPTIONS = nvfuser_version() >= LooseVersion("0.2.23")
2698+
assert (
2699+
NVFUSER_SUPPORTS_OPTIONS
2700+
), f"Installed version of nvFuser {nvfuser_version()} is not supported, please upgrade to 0.2.23 or later."

0 commit comments

Comments
 (0)