Skip to content

Commit ab65f9b

Browse files
author
Kacper Pietkun
authored
Add t.compile config (#62)
Signed-off-by: Kacper Pietkun <kpietkun@habana.ai>
1 parent 39bce5d commit ab65f9b

File tree

5 files changed

+112
-34
lines changed

5 files changed

+112
-34
lines changed

tests/unit_tests/worker/test_hpu_model_runner.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import pytest
55
import torch
66
import habana_frameworks.torch # noqa: F401
7+
from habana_frameworks.torch.utils.internal import is_lazy
8+
from vllm.model_executor.model_loader import get_model
79

810
from vllm.attention import Attention
911
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -667,3 +669,38 @@ def test_init_kv_cache_with_kv_sharing_valid():
667669
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
668670
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
669671
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
672+
673+
674+
@pytest.mark.skipif(is_lazy(),
                    reason="Test skipped because lazy mode is enabled.")
def test_model_torch_regional_compilation(dist_init, model_runner):
    """Verify that regional compilation wraps the expected submodules."""
    from vllm_gaudi.utils import HPUCompileConfig
    from vllm.model_executor.models.opt import OPTDecoderLayer
    from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding  # noqa
    from torch.nn.modules.normalization import LayerNorm
    from torch._dynamo.eval_frame import OptimizedModule

    def assert_compilation(model, layer_name, module):
        # A compiled region is an OptimizedModule whose _orig_mod is the
        # original (uncompiled) layer instance.
        wrapped = model.get_submodule(layer_name)
        assert isinstance(wrapped, OptimizedModule), (
            f"Layer: '{module.__name__}' was not wrapped with OptimizedModule"  # noqa
        )
        assert isinstance(wrapped._orig_mod, module), (
            f"_orig_mod is different from the original module: '{module.__name__}'"  # noqa
        )

    vllm_config = get_vllm_config()
    model = get_model(vllm_config=vllm_config)
    model_runner.compile_config = HPUCompileConfig()
    # OPT uses LayerNorm rather than RMSNorm, so override the runner's
    # default regional-compilation layer list for this model.
    model_runner.regional_compilation_layers_list = [
        LayerNorm, VocabParallelEmbedding
    ]

    model_runner._regional_compilation(model)

    decoder_layers = model.get_submodule("model.decoder.layers")
    for idx, _ in enumerate(decoder_layers):
        assert_compilation(model, f"model.decoder.layers.{idx}",
                           OPTDecoderLayer)
    assert_compilation(model, "lm_head", VocabParallelEmbedding)
    assert_compilation(model, "model.decoder.final_layer_norm", LayerNorm)
    assert_compilation(model, "model.decoder.embed_tokens",
                       VocabParallelEmbedding)

vllm_gaudi/extension/features.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,8 @@ def get_features():
7171
Value('exponential_bucketing', True, env_var='VLLM_EXPONENTIAL_BUCKETING'),
7272
Value('linear_bucketing', True),
7373
Value('bucketing_strategy', FirstEnabled(*bucketing_strategies), env_var_type=choice(*bucketing_strategies)),
74+
Value('regional_compilation', True, env_var='VLLM_T_COMPILE_REGIONAL_COMPILATION', env_var_type=boolean),
75+
Value('dynamic_shapes_compilation', False, env_var='VLLM_T_COMPILE_DYNAMIC_SHAPES', env_var_type=boolean),
76+
Value('fullgraph_compilation', False, env_var='VLLM_T_COMPILE_FULLGRAPH', env_var_type=boolean),
7477
]
7578
return split_values_and_flags(features)

vllm_gaudi/platform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING, Any, Optional
55

66
import torch
7+
import habana_frameworks.torch as htorch
78

89
from vllm import envs
910

@@ -144,7 +145,7 @@ def set_torch_compile(cls) -> None:
144145
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
145146
# torch.compile support
146147
os.environ['PT_HPU_WEIGHT_SHARING'] = '0'
147-
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '0') == '1'
148+
is_lazy = htorch.utils.internal.is_lazy()
148149
if is_lazy:
149150
torch._dynamo.config.disable = True
150151
# NOTE multi-HPU inference with HPUGraphs (lazy-only)

vllm_gaudi/utils.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from functools import cache
22
import os
33
from vllm.utils import make_tensor_with_pad, TORCH_DTYPE_TO_NUMPY_DTYPE
4-
from typing import (Optional, TypeVar, Union)
4+
from vllm_gaudi.extension.runtime import get_config
5+
from typing import (Any, Optional, TypeVar, Union)
56
import torch
67
import numpy as np
78
import numpy.typing as npt
@@ -108,3 +109,45 @@ def make_tensor_with_pad_align(
108109
tensor = tensor.pin_memory()
109110

110111
return tensor
112+
113+
114+
class HPUCompileConfig:
    """
    Configuration class, which holds arguments that will be
    passed to torch compile with HPU backend.
    """

    def __init__(self,
                 fullgraph: Optional[bool] = None,
                 dynamic: Optional[bool] = None):
        """
        Allow to override the environment variables for corner case scenarios
        when single functions are compiled with torch.compile decorator.
        Env variables should not be overwritten when it comes to compilation
        of the whole model.
        """
        # Explicit arguments win; otherwise fall back to the env-derived
        # feature flags exposed through get_config().
        self.fullgraph = (get_config().fullgraph_compilation
                          if fullgraph is None else fullgraph)
        self.dynamic = (get_config().dynamic_shapes_compilation
                        if dynamic is None else dynamic)
        self.regional_compilation = get_config().regional_compilation

    def get_compile_args(self) -> dict[str, Any]:
        """
        Returns a dictionary of compile arguments that can be used
        with torch.compile method or decorator
        """
        compile_args: dict[str, Any] = {
            'backend': 'hpu_backend',
            'fullgraph': self.fullgraph,
        }
        if self.dynamic:
            # Dynamic-shape mode still forces static compilation on HPU.
            compile_args['options'] = {"force_static_compile": True}
        else:
            compile_args['dynamic'] = False
        return compile_args

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
3838
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
3939
is_pin_memory_available, LazyLoader)
40-
from vllm_gaudi.utils import is_fake_hpu
40+
from vllm_gaudi.utils import HPUCompileConfig, is_fake_hpu
4141
from vllm_gaudi.v1.attention.backends.hpu_attn import HPUAttentionMetadataV1
4242
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
4343
KVCacheSpec)
@@ -1974,28 +1974,40 @@ def load_model(self) -> None:
19741974
self.model_memory_usage / float(2**30))
19751975

19761976
def _maybe_compile(self, *args, **kwargs):
    """Entrypoint for a torch.compilation of the model"""
    # Compilation only applies on real HPU, in eager (non-lazy) PT mode,
    # and when the user has not forced eager execution.
    if (is_fake_hpu() or htorch.utils.internal.is_lazy()
            or self.vllm_config.model_config.enforce_eager):
        return
    self.compile_config = HPUCompileConfig()
    if self.compile_config.regional_compilation:
        # Compile standalone helper methods first, then compile the
        # model region-by-region.
        self._compile_methods()
        self.regional_compilation_layers_list = [
            RMSNorm, VocabParallelEmbedding
        ]
        self._regional_compilation(self.model)
    else:
        self.model = self._compile(self.model)

1990+
def _compile_methods(self):
1991+
"""
1992+
Compile methods which are not part of the compiled model i.e. those
1993+
which will not be compiled during model's compilation.
1994+
"""
1995+
compiled_methods = ['_update_metadata', '_rotary_prepare_cos_sin']
1996+
for method_name in compiled_methods:
1997+
method = getattr(self.model, method_name)
1998+
if method is not None:
1999+
self._compile_region(self.model, method_name, method)
2000+
19952001
def _regional_compilation(self,
19962002
module,
19972003
parent_module=None,
19982004
module_name=None):
2005+
"""
2006+
Recursively traverses a PyTorch module and compiles its regions, which
2007+
can be one of two:
2008+
1. Children of the nn.ModuleList
2009+
2. Member of regional_compilation_layers_list
2010+
"""
19992011
if isinstance(module, torch.nn.ModuleList):
20002012
for children_name, children_module in module.named_children():
20012013
self._compile_region(module, children_name, children_module)
@@ -2017,24 +2029,7 @@ def _compile_region(self, model, name, module):
20172029
setattr(model, name, module)
20182030

20192031
def _compile(self, module):
2020-
if not hasattr(self, '_compile_config'):
2021-
fullgraph = os.getenv('VLLM_T_COMPILE_FULLGRAPH',
2022-
'false').strip().lower() in ("1", "true")
2023-
dynamic = os.getenv('VLLM_T_COMPILE_DYNAMIC_SHAPES',
2024-
'false').strip().lower() in ("1", "true")
2025-
self._compile_config = {'fullgraph': fullgraph, 'dynamic': dynamic}
2026-
fullgraph = self._compile_config['fullgraph']
2027-
dynamic = self._compile_config['dynamic']
2028-
if dynamic:
2029-
return torch.compile(module,
2030-
backend='hpu_backend',
2031-
fullgraph=fullgraph,
2032-
options={"force_static_compile": True})
2033-
else:
2034-
return torch.compile(module,
2035-
backend='hpu_backend',
2036-
fullgraph=fullgraph,
2037-
dynamic=False)
2032+
return torch.compile(module, **self.compile_config.get_compile_args())
20382033

20392034
def _use_graphs(self):
20402035
return not self.model_config.enforce_eager
@@ -2352,8 +2347,7 @@ def warmup_model(self) -> None:
23522347

23532348
if not htorch.utils.internal.is_lazy(
23542349
) and not self.model_config.enforce_eager:
2355-
multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION',
2356-
'true').lower() in ('1', 'true') else 1
2350+
multiplier = 5 if self.compile_config.regional_compilation else 1
23572351
cache_size_limit = 1 + multiplier * (
23582352
len(self.bucketing_manager.prompt_buckets) +
23592353
len(self.bucketing_manager.decode_buckets))

0 commit comments

Comments
 (0)