Commit f7030df
[Core][LoRA][1/N] Add LoRA for EncoderDecoderModelRunner (#15990)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent: 905e91e

3 files changed: +41 -1

vllm/lora/layers.py

Lines changed: 5 additions & 0 deletions
@@ -866,6 +866,11 @@ def can_replace_layer(
                 and len(packed_modules_list) == 3)


+#TODO: Implement this
+class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
+    pass
+
+
 class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):

     def __init__(self, base_layer: RowParallelLinear) -> None:
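
The new class is only a placeholder in this first commit of the series (hence the [1/N] in the title). As a rough sketch of where the follow-up is likely headed, assuming it mirrors the can_replace_layer pattern of sibling layers such as RowParallelLinearWithLoRA; the base-layer type QKVCrossParallelLinear and the signature below are assumptions, not part of this diff:

# Hypothetical follow-up, NOT in this commit: gate replacement on the
# cross-attention QKV base layer, mirroring the sibling LoRA layers.
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):

    @classmethod
    def can_replace_layer(cls, source_layer, lora_config,
                          packed_modules_list, model_config) -> bool:
        # Assumed check: only cross-attention QKV projections qualify.
        return isinstance(source_layer, QKVCrossParallelLinear)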

vllm/model_executor/models/mllama.py

Lines changed: 11 additions & 0 deletions
@@ -52,6 +52,7 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,

@@ -1181,6 +1182,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: MllamaConfig = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        self.config = config
         self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size

@@ -1517,6 +1519,15 @@ def load_weights(self, weights: Iterable[Tuple[str,
             updated_params.add(name)
         return updated_params

+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector",
+            tower_model="vision_model")
+

 def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
     for mask in sparse_mask:
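
get_mm_mapping gives the LoRA machinery a way to tell the text decoder (which should receive adapters) apart from the vision tower and the projector (which should not). A minimal sketch of how a caller might use it, assuming from_string_field normalizes each field of MultiModelKeys to a list of module-name prefixes; the is_lora_target helper is hypothetical:

# Illustrative only, not part of the commit. Assumes each field of
# MultiModelKeys holds a list of module-name prefixes.
mapping = model.get_mm_mapping()

def is_lora_target(module_name: str) -> bool:
    # Hypothetical helper: apply LoRA to language-model modules only,
    # leaving the vision tower and the projector untouched.
    return any(module_name.startswith(prefix)
               for prefix in mapping.language_model)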

vllm/worker/enc_dec_model_runner.py

Lines changed: 25 additions & 1 deletion
@@ -16,6 +16,7 @@
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,

@@ -34,6 +35,7 @@
 from vllm.worker.utils import assert_enc_dec_mr_supported_scenario

 logger = init_logger(__name__)
+LORA_WARMUP_RANK = 8


 @dataclasses.dataclass(frozen=True)
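
LORA_WARMUP_RANK fixes the rank of the placeholder adapters created during memory profiling in profile_run below; it appears to mirror the constant of the same name in vLLM's decoder-only GPU model runner.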
@@ -160,7 +162,11 @@ def execute_model(
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in "
                              "EncoderDecoderModelRunner")
-
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
         if (model_input.attn_metadata is not None
                 and model_input.attn_metadata.prefill_metadata is None
                 and model_input.attn_metadata.decode_metadata.use_cuda_graph):
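
With LoRA enabled, the runner must activate the batch's adapters before the forward pass; set_active_loras is inherited from the LoRA-aware base runner. For intuition, a rough illustration of the mapping it consumes, with the import location and field semantics assumed from vLLM's LoRAMapping dataclass:

# Rough illustration (assumptions noted inline): a batch of two
# sequences, three prefill tokens on adapter 1 and two on adapter 2.
from vllm.lora.layers import LoRAMapping  # import location assumed

mapping = LoRAMapping(
    index_mapping=(1, 1, 1, 2, 2),  # adapter id per token (assumed)
    prompt_mapping=(1, 2),          # adapter id per sequence group (assumed)
    is_prefill=True,
)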
@@ -268,6 +274,22 @@ def profile_run(self) -> None:
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs

+        # This represents the maximum number of different requests
+        # that will have unique loras, and therefore the max amount of
+        # memory consumption. Create dummy lora request copies from the
+        # lora request passed in, which contains a lora from the lora
+        # warmup path.
+        dummy_lora_requests: List[LoRARequest] = []
+        dummy_lora_requests_per_seq: List[LoRARequest] = []
+        if self.lora_config:
+            dummy_lora_requests = self._add_dummy_loras(
+                self.lora_config.max_loras)
+            assert len(dummy_lora_requests) == self.lora_config.max_loras
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
+
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
         seqs: List[SequenceGroupMetadata] = []
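
The helper _add_dummy_loras is not shown in this diff. A plausible sketch of what it does, assuming it follows the dummy-LoRA warmup pattern of vLLM's decoder-only model runner (dummy_lora_cache, add_dummy_lora, and the fake path are drawn from that pattern, not from this commit):

# Plausible sketch only; the actual helper is outside this diff.
def _add_dummy_loras(self, num_loras: int) -> List[LoRARequest]:
    dummy_lora_requests: List[LoRARequest] = []
    with self.lora_manager.dummy_lora_cache():
        for idx in range(num_loras):
            lora_id = idx + 1
            request = LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
                lora_path="/not/a/real/path",
            )
            # Register a placeholder adapter at the fixed warmup rank so
            # profiling reserves memory for max_loras resident adapters.
            self.lora_manager.add_dummy_lora(request,
                                             rank=LORA_WARMUP_RANK)
            dummy_lora_requests.append(request)
    return dummy_lora_requests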
@@ -315,6 +337,8 @@ def profile_run(self) -> None:
                 block_tables=None,
                 encoder_seq_data=encoder_dummy_data.seq_data,
                 cross_block_table=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
                 multi_modal_data=decoder_dummy_data.multi_modal_data
                 or encoder_dummy_data.multi_modal_data,
                 multi_modal_placeholders=decoder_dummy_data.
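
Each dummy sequence group picks its adapter round-robin from the warmup pool, so all max_loras adapters are resident at once while memory is profiled. A quick illustration with made-up numbers:

# Illustrative numbers only, not vLLM defaults: 4 adapters, 10 sequences.
max_loras, max_num_seqs = 4, 10
dummy = [f"lora-{i + 1}" for i in range(max_loras)]
per_seq = [dummy[idx % len(dummy)] for idx in range(max_num_seqs)]
# per_seq == ['lora-1', 'lora-2', 'lora-3', 'lora-4', 'lora-1', ...]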
