Commit 1fed5dd

precommit

Signed-off-by: NickLucche <[email protected]>

1 parent: bbf097f

File tree: 3 files changed, +50 -24 lines

vllm/lora/layers/base_linear.py
Lines changed: 6 additions & 2 deletions

@@ -157,8 +157,12 @@ def apply(
         # In transformers backend, x and output have extra batch dimension like
         # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
         # therefore we need to flatten the batch dimensions.
-        if (x.shape[0] ==1 and x.ndim == 3 and
-                output.shape[0] == 1 and output.ndim == 3):
+        if (
+            x.shape[0] == 1
+            and x.ndim == 3
+            and output.shape[0] == 1
+            and output.ndim == 3
+        ):
             output = output.flatten(0, 1)
             x = x.flatten(0, 1)

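The comment in this hunk explains why the condition exists; as a standalone illustration of the behavior (tensor shapes are assumptions for this sketch, not taken from the commit):

import torch

# Transformers-backend activations carry a leading batch dimension,
# e.g. (1, seq_len, hidden_dim); the punica LoRA kernels expect the
# 2-D form (seq_len, hidden_dim), so the singleton batch dim is merged.
x = torch.randn(1, 16, 4096)  # assumed shape for illustration
if x.shape[0] == 1 and x.ndim == 3:
    x = x.flatten(0, 1)  # (1, 16, 4096) -> (16, 4096)
assert x.shape == (16, 4096)
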
vllm/model_executor/models/gemma3n.py
Lines changed: 8 additions & 3 deletions

@@ -20,8 +20,8 @@
 
 import torch
 from torch import nn
-
 from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
+
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig

@@ -53,8 +53,13 @@
 from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
 
 from .interfaces import SupportsLoRA, SupportsQuant
-from .utils import (AutoWeightsLoader, extract_layer_index,
-                    is_pp_missing_parameter, make_layers, maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    extract_layer_index,
+    is_pp_missing_parameter,
+    make_layers,
+    maybe_prefix,
+)
 
 logger = init_logger(__name__)
 
vllm/model_executor/models/gemma3n_mm.py
Lines changed: 36 additions & 19 deletions

@@ -55,10 +55,19 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsTranscription)
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix)
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsTranscription,
+)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    flatten_bn,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
 
 logger = init_logger(__name__)
 
@@ -397,8 +406,11 @@ def __init__(
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
 
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
         self.vocab_size = self.vocab_size + lora_vocab
 
         self.embedding = VocabParallelEmbedding(
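
Rewrapped this way, the padding arithmetic is easy to verify by hand. A quick sketch with hypothetical values (these numbers are illustrative, not vLLM defaults):

# Hypothetical config values, for illustration only.
lora_extra_vocab_size = 256
max_loras = 4
lora_vocab = lora_extra_vocab_size * (max_loras or 1)  # 256 * 4 = 1024
# With no LoRA config at all, the conditional expression yields 0,
# and self.vocab_size is left unpadded.
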
@@ -446,7 +458,8 @@ def forward(
         """ # noqa: E501
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds")
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
         if inputs_embeds is not None:
             emb_norm = self.soft_embedding_norm(inputs_embeds)
         else:
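
The unchanged guard in this hunk uses XOR to enforce that exactly one of input_ids and inputs_embeds is supplied. A standalone truth-table check (sketch only, not vLLM code):

def must_raise(input_ids, inputs_embeds):
    # Mirrors the guard in forward(): True means ValueError is raised.
    return (input_ids is None) ^ (inputs_embeds is not None)

assert must_raise(None, None)           # neither given -> error
assert must_raise("ids", "embeds")      # both given -> error
assert not must_raise("ids", None)      # only input_ids -> ok
assert not must_raise(None, "embeds")   # only inputs_embeds -> ok
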
@@ -457,11 +470,14 @@
         return self.embedding_post_projection_norm(emb_norm_proj)
 
 
-@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
-                                        info=Gemma3nProcessingInfo,
-                                        dummy_inputs=Gemma3nDummyInputsBuilder)
-class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsTranscription, SupportsLoRA):
+@MULTIMODAL_REGISTRY.register_processor(
+    Gemma3nMultiModalProcessor,
+    info=Gemma3nProcessingInfo,
+    dummy_inputs=Gemma3nDummyInputsBuilder,
+)
+class Gemma3nForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsTranscription, SupportsLoRA
+):
     merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS
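
For readers unfamiliar with the pattern, the decorator above registers the model class with vLLM's multimodal registry at import time. A generic sketch of such a registry decorator (hypothetical and simplified; not vLLM's actual MULTIMODAL_REGISTRY implementation):

# Simplified, hypothetical registry-decorator pattern.
_PROCESSOR_REGISTRY: dict[type, dict] = {}

def register_processor(processor_cls, *, info=None, dummy_inputs=None):
    def wrap(model_cls):
        # Record which processor (and helpers) serve this model class.
        _PROCESSOR_REGISTRY[model_cls] = {
            "processor": processor_cls,
            "info": info,
            "dummy_inputs": dummy_inputs,
        }
        return model_cls  # the class itself is returned unchanged
    return wrap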

@@ -504,12 +520,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
-        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
-                                                      config.text_config,
-                                                      self.lora_config)
-        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
-                                                     config.text_config,
-                                                     self.lora_config)
+        self.embed_vision = Gemma3nMultimodalEmbedder(
+            config.vision_config, config.text_config, self.lora_config
+        )
+        self.embed_audio = Gemma3nMultimodalEmbedder(
+            config.audio_config, config.text_config, self.lora_config
+        )
 
         self.language_model: nn.Module = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -744,7 +760,8 @@ def get_mm_mapping(self) -> MultiModelKeys:
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model=["vision_tower", "audio_tower"])
+            tower_model=["vision_tower", "audio_tower"],
+        )
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
