@@ -434,6 +434,7 @@ class FusionDefinitionWrapper:
434434 enable_options : None | list [str ] = None
435435 disable_options : None | list [str ] = None
436436
437+ @annotate_for_profile ("FusionDefinitionWrapper.__call__" )
437438 def __call__ (self , * args ):
438439 fd = self .get_fd (self .to_descriptors (args ))
439440 self .last_used = fd
@@ -451,18 +452,10 @@ def __call__(self, *args):
451452 if hasattr (fd , "_selected_device" ):
452453 kwargs ["device" ] = fd ._selected_device
453454
454- if nvfuser_version () >= LooseVersion ("0.2.23" ):
455- # nvFuser expects empty list instead of None values.
456- kwargs ["_enable_options" ] = self .enable_options if self .enable_options is not None else []
457- kwargs ["_disable_options" ] = self .disable_options if self .disable_options is not None else []
458-
459- elif self .enable_options or self .disable_options :
460- warnings .warn (
461- f"nv_enable_options/nv_disable_options require nvFuser version 0.2.23 and above, found version { nvfuser_version ()} . These options will be ignored."
462- )
463-
464455 with annotate_for_profile (self .name ):
465- return fd .execute (args , ** kwargs )
456+ return fd .execute (
457+ args , _enable_options = self .enable_options , _disable_options = self .disable_options , ** kwargs
458+ )
466459
467460 def __repr__ (self ):
468461 return f"FusionDefinitionWrapper({ self .name } )"
@@ -558,9 +551,9 @@ def create_fusion_definition_wrapper(
558551 store_inputs_meta : None | bool = get_compile_option (
559552 "nv_store_fusion_inputs_meta" , "Allow nvFuser to store fusion inputs metadata for repro."
560553 )
561- enable_options : None | list [str ] = get_compile_option ("nv_enable_options" , "List of NVFUSER_ENABLE options to set." )
562- disable_options : None | list [str ] = get_compile_option (
563- "nv_disable_options" , "List of NVFUSER_DISABLE options to set."
554+ enable_options : list [str ] = get_compile_option ("nv_enable_options" , "List of NVFUSER_ENABLE options to set." ) or []
555+ disable_options : list [str ] = (
556+ get_compile_option ( "nv_disable_options" , "List of NVFUSER_DISABLE options to set." ) or []
564557 )
565558
566559 tensor_indices = []
@@ -2698,3 +2691,10 @@ def embedding(
26982691
26992692register_supported (PrimIDs .EMBEDDING , embedding , _embedding_check )
27002693register_supported (ltorch .embedding , embedding , _embedding_check )
2694+
2695+
2696+ # At module/class level
2697+ NVFUSER_SUPPORTS_OPTIONS = nvfuser_version () >= LooseVersion ("0.2.23" )
2698+ assert (
2699+ NVFUSER_SUPPORTS_OPTIONS
2700+ ), f"Installed version of nvFuser { nvfuser_version ()} is not supported, please upgrade to 0.2.23 or later."
0 commit comments