3737from vllm .transformers_utils .tokenizer_group import init_tokenizer_from_configs
3838from vllm .utils import (STR_DTYPE_TO_TORCH_DTYPE , LayerBlockType , cdiv ,
3939 is_pin_memory_available , LazyLoader )
40- from vllm_gaudi .utils import is_fake_hpu
40+ from vllm_gaudi .utils import HPUCompileConfig , is_fake_hpu
4141from vllm_gaudi .v1 .attention .backends .hpu_attn import HPUAttentionMetadataV1
4242from vllm .v1 .kv_cache_interface import (FullAttentionSpec , KVCacheConfig ,
4343 KVCacheSpec )
@@ -1974,28 +1974,40 @@ def load_model(self) -> None:
19741974 self .model_memory_usage / float (2 ** 30 ))
19751975
def _maybe_compile(self, *args, **kwargs):
    """Entry point for torch.compile-based compilation of the model.

    Always stores a fresh ``HPUCompileConfig`` on ``self.compile_config``.
    Compilation itself is skipped when ``is_fake_hpu()`` is true, when
    ``htorch.utils.internal.is_lazy()`` is true, or when
    ``enforce_eager`` is set in the model config. Otherwise:

    * if ``compile_config.regional_compilation`` is set, the extra model
      methods are compiled via ``_compile_methods`` and the model is
      traversed by ``_regional_compilation``;
    * else the whole model is replaced by the result of ``_compile``.

    ``*args`` and ``**kwargs`` are accepted but not used here.
    """
    # Build the config before the guard: warmup_model reads
    # self.compile_config under a condition that does not check
    # is_fake_hpu(), so the attribute has to exist even when compilation
    # is skipped.
    # NOTE(review): assumes constructing HPUCompileConfig has no device
    # side effects -- confirm against vllm_gaudi.utils.
    self.compile_config = HPUCompileConfig()

    # Same checks, in the same short-circuit order, as the original
    # "not a and not b and not c" condition.
    if (is_fake_hpu() or htorch.utils.internal.is_lazy()
            or self.vllm_config.model_config.enforce_eager):
        return

    if self.compile_config.regional_compilation:
        self._compile_methods()
        # Layer types that _regional_compilation compiles as regions.
        self.regional_compilation_layers_list = [
            RMSNorm, VocabParallelEmbedding
        ]
        self._regional_compilation(self.model)
    else:
        self.model = self._compile(self.model)
19941989
def _compile_methods(self):
    """Compile model methods that are not covered by model compilation.

    For each of ``_update_metadata`` and ``_rotary_prepare_cos_sin``, the
    attribute is looked up on ``self.model`` and, when present and not
    ``None``, handed to ``self._compile_region`` together with the model
    and the attribute name.
    """
    compiled_methods = ('_update_metadata', '_rotary_prepare_cos_sin')
    for method_name in compiled_methods:
        # Default to None so that a model lacking one of these helpers is
        # skipped instead of raising AttributeError; without the default
        # the "is not None" check below could never skip a missing method.
        method = getattr(self.model, method_name, None)
        if method is not None:
            self._compile_region(self.model, method_name, method)
2000+
19952001 def _regional_compilation (self ,
19962002 module ,
19972003 parent_module = None ,
19982004 module_name = None ):
2005+ """
2006+ Recursively traverses a PyTorch module and compiles its regions, which
2007+ can be one of two:
2008+ 1. Children of the nn.ModuleList
2009+ 2. Member of regional_compilation_layers_list
2010+ """
19992011 if isinstance (module , torch .nn .ModuleList ):
20002012 for children_name , children_module in module .named_children ():
20012013 self ._compile_region (module , children_name , children_module )
@@ -2017,24 +2029,7 @@ def _compile_region(self, model, name, module):
20172029 setattr (model , name , module )
20182030
def _compile(self, module):
    """Return ``module`` wrapped by ``torch.compile``.

    The keyword arguments passed to ``torch.compile`` are taken from
    ``self.compile_config.get_compile_args()``.
    """
    compile_kwargs = self.compile_config.get_compile_args()
    return torch.compile(module, **compile_kwargs)
20382033
def _use_graphs(self):
    """Return the negation of ``self.model_config.enforce_eager``."""
    eager_only = self.model_config.enforce_eager
    return not eager_only
@@ -2352,8 +2347,7 @@ def warmup_model(self) -> None:
23522347
23532348 if not htorch .utils .internal .is_lazy (
23542349 ) and not self .model_config .enforce_eager :
2355- multiplier = 3 if os .getenv ('VLLM_REGIONAL_COMPILATION' ,
2356- 'true' ).lower () in ('1' , 'true' ) else 1
2350+ multiplier = 5 if self .compile_config .regional_compilation else 1
23572351 cache_size_limit = 1 + multiplier * (
23582352 len (self .bucketing_manager .prompt_buckets ) +
23592353 len (self .bucketing_manager .decode_buckets ))
0 commit comments