Commit 6ae1597

[VLM] Minor space optimization for ClipVisionModel (#6436)
1 parent 22e79ee commit 6ae1597

File tree

vllm/model_executor/models/clip.py
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/phi3v.py

4 files changed (+66, -39 lines)

vllm/model_executor/models/clip.py

Lines changed: 25 additions & 21 deletions
@@ -214,22 +214,24 @@ class CLIPEncoder(nn.Module):

     def __init__(self,
                  config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 num_hidden_layers_override: Optional[int] = None):
         super().__init__()
         self.config = config
+
+        if num_hidden_layers_override is None:
+            num_hidden_layers = config.num_hidden_layers
+        else:
+            num_hidden_layers = num_hidden_layers_override
         self.layers = nn.ModuleList([
             CLIPEncoderLayer(config=config, quant_config=quant_config)
-            for _ in range(config.num_hidden_layers)
+            for _ in range(num_hidden_layers)
         ])

-    def forward(self,
-                inputs_embeds: torch.Tensor,
-                vision_feature_layer: int = -1):
+    def forward(self, inputs_embeds: torch.Tensor):

-        # Encoder forward pass only up to the required layer
-        num_layer = len(self.layers) + vision_feature_layer + 1
         hidden_states = inputs_embeds
-        for encoder_layer in self.layers[:num_layer]:
+        for encoder_layer in self.layers:
             hidden_states = encoder_layer(hidden_states)

         return hidden_states
@@ -239,7 +241,8 @@ class CLIPVisionTransformer(nn.Module):

     def __init__(self,
                  config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 num_hidden_layers_override: Optional[int] = None):
         super().__init__()
         self.config = config
         embed_dim = config.hidden_size
@@ -249,18 +252,19 @@ def __init__(self,
         # NOTE: This typo of "layrnorm" is not fixed on purpose to match
         # the original transformers code and name of the model weights.
         self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self.encoder = CLIPEncoder(config=config, quant_config=quant_config)
+        self.encoder = CLIPEncoder(
+            config=config,
+            quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override)

     def forward(
         self,
         pixel_values: torch.Tensor,
-        vision_feature_layer: int = -1,
     ) -> torch.Tensor:

         hidden_states = self.embeddings(pixel_values)
         hidden_states = self.pre_layrnorm(hidden_states)
-        hidden_states = self.encoder(inputs_embeds=hidden_states,
-                                     vision_feature_layer=vision_feature_layer)
+        hidden_states = self.encoder(inputs_embeds=hidden_states)

         return hidden_states

@@ -272,17 +276,17 @@ class CLIPVisionModel(nn.Module):

     def __init__(self,
                  config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 num_hidden_layers_override: Optional[int] = None):
         super().__init__()
-        self.vision_model = CLIPVisionTransformer(config=config,
-                                                  quant_config=quant_config)
+        self.vision_model = CLIPVisionTransformer(
+            config=config,
+            quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override)

-    def forward(self,
-                pixel_values: Optional[torch.Tensor] = None,
-                vision_feature_layer: int = -1):
+    def forward(self, pixel_values: Optional[torch.Tensor] = None):

-        return self.vision_model(pixel_values=pixel_values,
-                                 vision_feature_layer=vision_feature_layer)
+        return self.vision_model(pixel_values=pixel_values)

     @property
     def device(self):
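
The net effect of the clip.py change is that layer selection moves from forward time to construction time: instead of building all config.num_hidden_layers encoder layers and slicing the list on every forward pass, the encoder only ever allocates the layers it will actually run, so the skipped top layers no longer cost parameter memory. Below is a minimal sketch of that pattern in plain PyTorch; the toy module and its names are illustrative, not the vLLM classes above.

    from typing import Optional

    import torch.nn as nn

    class TinyEncoder(nn.Module):
        # Illustrative stand-in for CLIPEncoder's construction logic.

        def __init__(self,
                     num_hidden_layers: int,
                     num_hidden_layers_override: Optional[int] = None):
            super().__init__()
            if num_hidden_layers_override is None:
                num_layers = num_hidden_layers
            else:
                num_layers = num_hidden_layers_override
            self.layers = nn.ModuleList(
                nn.Linear(8, 8) for _ in range(num_layers))

        def forward(self, x):
            # No slicing at forward time: every layer that exists is run.
            for layer in self.layers:
                x = layer(x)
            return x

    full = TinyEncoder(num_hidden_layers=24)
    trimmed = TinyEncoder(num_hidden_layers=24, num_hidden_layers_override=23)
    assert len(trimmed.layers) == 23  # the unused top layer is never allocated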

vllm/model_executor/models/llava.py

Lines changed: 12 additions & 4 deletions
@@ -128,8 +128,17 @@ def __init__(self,
         self.config = config
         self.multimodal_config = multimodal_config

+        # Initialize the vision tower only up to the required feature layer
+        vision_feature_layer = config.vision_feature_layer
+        if vision_feature_layer < 0:
+            num_hidden_layers = config.vision_config.num_hidden_layers \
+                + vision_feature_layer + 1
+        else:
+            num_hidden_layers = vision_feature_layer + 1
+
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = CLIPVisionModel(config.vision_config)
+        self.vision_tower = CLIPVisionModel(
+            config.vision_config, num_hidden_layers_override=num_hidden_layers)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
@@ -193,8 +202,7 @@ def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,

         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values,
-                                      self.config.vision_feature_layer)
+        image_features = vision_tower(pixel_values)

         return self._select_image_features(
             image_features,
@@ -333,7 +341,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     break
             else:
                 use_default_weight_loading = True
-            if use_default_weight_loading:
+            if use_default_weight_loading and name in params_dict:
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
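
The __init__ hunk above derives how many encoder layers to build from config.vision_feature_layer, and the same computation is repeated in llava_next.py and phi3v.py below. A worked sketch of the arithmetic follows; the helper name is invented for illustration.

    def layers_needed(num_hidden_layers: int, vision_feature_layer: int) -> int:
        # Mirrors the computation in the diff: negative indices count from
        # the end of the encoder, non-negative ones are zero-based.
        if vision_feature_layer < 0:
            return num_hidden_layers + vision_feature_layer + 1
        return vision_feature_layer + 1

    assert layers_needed(24, -2) == 23  # typical LLaVA setting: drop the top layer
    assert layers_needed(24, -1) == 24  # last layer requested: keep everything
    assert layers_needed(24, 5) == 6    # positive index: layers 0..5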

vllm/model_executor/models/llava_next.py

Lines changed: 12 additions & 4 deletions
@@ -222,8 +222,17 @@ def __init__(self,
         self.config = config
         self.multimodal_config = multimodal_config

+        # Initialize the vision tower only up to the required feature layer
+        vision_feature_layer = config.vision_feature_layer
+        if vision_feature_layer < 0:
+            num_hidden_layers = config.vision_config.num_hidden_layers \
+                + vision_feature_layer + 1
+        else:
+            num_hidden_layers = vision_feature_layer + 1
+
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = CLIPVisionModel(config=config.vision_config)
+        self.vision_tower = CLIPVisionModel(
+            config.vision_config, num_hidden_layers_override=num_hidden_layers)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
@@ -312,8 +321,7 @@ def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,

         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values,
-                                      self.config.vision_feature_layer)
+        image_features = vision_tower(pixel_values)

         return self._select_image_features(
             image_features,
@@ -561,7 +569,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     break
             else:
                 use_default_weight_loading = True
-            if use_default_weight_loading:
+            if use_default_weight_loading and name in params_dict:
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

vllm/model_executor/models/phi3v.py

Lines changed: 17 additions & 10 deletions
@@ -80,13 +80,11 @@ def __init__(self, wte=None) -> None:

     def get_img_features(self,
                          img_embeds: torch.FloatTensor) -> torch.FloatTensor:
-        LAYER_IDX = self.layer_idx
         TYPE_FEATURE = self.type_feature

         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the img_processor
-        img_feature = self.img_processor(img_embeds,
-                                         vision_feature_layer=LAYER_IDX)
+        img_feature = self.img_processor(img_embeds)

         if TYPE_FEATURE == "patch":
             patch_feature = img_feature[:, 1:]
@@ -111,7 +109,17 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None:
             config, 'n_embd') else config.hidden_size

         clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
-        self.img_processor = CLIPVisionModel(clip_config)
+        self.layer_idx = config.img_processor.get('layer_idx', -2)
+
+        # Initialize the CLIP only up to the required feature layer
+        if self.layer_idx < 0:
+            num_hidden_layers = clip_config.num_hidden_layers + \
+                self.layer_idx + 1
+        else:
+            num_hidden_layers = self.layer_idx + 1
+
+        self.img_processor = CLIPVisionModel(
+            clip_config, num_hidden_layers_override=num_hidden_layers)
         image_dim_out = config.img_processor['image_dim_out']
         self.num_img_tokens = config.img_processor['num_img_tokens']

@@ -142,8 +150,6 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None:
         self.img_projection = nn.Sequential(*layers)

         self.vocab_size = config.vocab_size
-
-        self.layer_idx = config.img_processor.get('layer_idx', -2)
         self.type_feature = config.img_processor.get('type_feature', 'patch')

     def forward(self, input_ids: torch.LongTensor,
@@ -588,7 +594,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
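
Because the truncated vision tower no longer owns parameters for the dropped layers, the load_weights hunks in llava.py, llava_next.py, and this file guard the default loading path with a name in params_dict check, so checkpoint tensors that have no matching parameter are skipped instead of raising a KeyError. A standalone sketch of that guard follows, using illustrative names rather than the vLLM helpers.

    from typing import Dict, Iterable, Tuple

    import torch
    import torch.nn as nn

    def load_with_guard(params_dict: Dict[str, nn.Parameter],
                        weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        for name, loaded_weight in weights:
            if name not in params_dict:
                # e.g. weights of vision-tower layers that were never built
                continue
            param = params_dict[name]
            # Prefer a parameter-specific loader when one is attached,
            # otherwise fall back to a plain copy.
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is not None:
                weight_loader(param, loaded_weight)
            else:
                param.data.copy_(loaded_weight)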
