
Commit e4bb268

[Models] Replace all nn.Conv2d with vLLM's Conv2dLayer (#28842)
Signed-off-by: Isotr0py <[email protected]>
1 parent c64c0b7 commit e4bb268

20 files changed, +83 -45 lines
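The change applied across the touched model files is a mechanical swap: every nn.Conv2d / torch.nn.Conv2d construction becomes vLLM's Conv2dLayer with the same constructor arguments. A minimal sketch of the pattern, using a hypothetical patch-embedding module (the class and its parameters below are stand-ins, not one of the files in this commit; it assumes Conv2dLayer is constructor-compatible with nn.Conv2d, as the hunks below suggest):

import torch.nn as nn

from vllm.model_executor.layers.conv import Conv2dLayer


class PatchEmbed(nn.Module):
    """Hypothetical stand-in for the patch-embedding modules touched in this commit."""

    def __init__(self, num_channels: int, hidden_size: int, patch_size: int):
        super().__init__()
        # Before this commit: self.proj = nn.Conv2d(...)
        # After: the vLLM layer, constructed with identical arguments.
        self.proj = Conv2dLayer(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )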

vllm/model_executor/layers/conv.py

Lines changed: 22 additions & 2 deletions
@@ -3,6 +3,7 @@
 """Conv Layer Class."""
 
 import math
+from typing import Literal
 
 import torch
 import torch.nn as nn
@@ -23,11 +24,11 @@ def __init__(
         out_channels: int,
         kernel_size: int | tuple[int, ...],
         stride: int | tuple[int, ...] = 1,
-        padding: int | tuple[int, ...] = 0,
+        padding: int | tuple[int, ...] | Literal["same", "valid"] = 0,
         dilation: int | tuple[int, ...] = 1,
         groups: int = 1,
         bias: bool = True,
-        padding_mode: str = "zeros",
+        padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
         *,
         params_dtype: torch.dtype | None = None,
     ) -> None:
@@ -36,6 +37,22 @@ def __init__(
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
+        valid_padding_strings = {"same", "valid"}
+        if isinstance(padding, str) and padding not in valid_padding_strings:
+            raise ValueError(
+                f"Invalid padding string '{padding}'. "
+                f"Expected one of {valid_padding_strings}."
+            )
+
+        if padding == "same":
+            padding = (
+                kernel_size // 2
+                if isinstance(kernel_size, int)
+                else tuple(k // 2 for k in kernel_size)
+            )
+        elif padding == "valid":
+            padding = 0
+
         kernel_size = (
             (kernel_size,) * self.num_dim
             if isinstance(kernel_size, int)
@@ -45,6 +62,9 @@ def __init__(
         padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
         dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation
 
+        if padding == "same" and any(s != 1 for s in stride):
+            raise ValueError("padding='same' is not supported for strided convolutions")
+
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
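Beyond the model swaps, the conv.py hunk above also lets Conv2dLayer accept the string paddings that torch's Conv2d supports. A usage sketch of the resolution rules exactly as added in that hunk (the forward-pass behaviour is assumed to match nn.Conv2d; the channel and kernel numbers are illustrative):

from vllm.model_executor.layers.conv import Conv2dLayer

# "same" resolves to kernel_size // 2 per dimension, so these two are equivalent:
conv_a = Conv2dLayer(3, 16, kernel_size=3, padding="same")
conv_b = Conv2dLayer(3, 16, kernel_size=3, padding=1)

# "valid" resolves to no padding at all:
conv_c = Conv2dLayer(3, 16, kernel_size=3, padding="valid")
conv_d = Conv2dLayer(3, 16, kernel_size=3, padding=0)

# Any other padding string raises a ValueError; the hunk also adds a guard
# against combining padding="same" with a stride other than 1.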

vllm/model_executor/models/aimv2.py

Lines changed: 2 additions & 1 deletion
@@ -12,6 +12,7 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -58,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class AIMv2PatchEmbed(nn.Module):
     def __init__(self, config: AIMv2Config):
         super().__init__()
-        self.proj = nn.Conv2d(
+        self.proj = Conv2dLayer(
             config.num_channels,
             config.hidden_size,
             kernel_size=(config.patch_size, config.patch_size),

vllm/model_executor/models/blip.py

Lines changed: 2 additions & 1 deletion
@@ -12,6 +12,7 @@
 from vllm.attention.layer import MultiHeadAttention
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -47,7 +48,7 @@ def __init__(self, config: BlipVisionConfig | Blip2VisionConfig):
 
         self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
 
-        self.patch_embedding = nn.Conv2d(
+        self.patch_embedding = Conv2dLayer(
             in_channels=3,
             out_channels=self.embed_dim,
             kernel_size=self.patch_size,

vllm/model_executor/models/chameleon.py

Lines changed: 14 additions & 15 deletions
@@ -22,6 +22,7 @@
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -549,7 +550,7 @@ def forward(self, hidden_state: torch.Tensor):
 class ChameleonVQVAEEncoderConvDownsample(nn.Module):
     def __init__(self, in_channels: int):
         super().__init__()
-        self.conv = nn.Conv2d(
+        self.conv = Conv2dLayer(
             in_channels, in_channels, kernel_size=3, stride=2, padding=0
         )
 
@@ -577,23 +578,23 @@ def __init__(
         self.norm1 = torch.nn.GroupNorm(
             num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
         )
-        self.conv1 = torch.nn.Conv2d(
+        self.conv1 = Conv2dLayer(
             in_channels, out_channels, kernel_size=3, stride=1, padding=1
         )
         self.norm2 = torch.nn.GroupNorm(
             num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
         )
         self.dropout = torch.nn.Dropout(config.dropout)
-        self.conv2 = torch.nn.Conv2d(
+        self.conv2 = Conv2dLayer(
             out_channels, out_channels, kernel_size=3, stride=1, padding=1
         )
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(
+                self.conv_shortcut = Conv2dLayer(
                     in_channels, out_channels, kernel_size=3, stride=1, padding=1
                 )
             else:
-                self.nin_shortcut = torch.nn.Conv2d(
+                self.nin_shortcut = Conv2dLayer(
                     in_channels, out_channels, kernel_size=1, stride=1, padding=0
                 )
 
@@ -626,16 +627,16 @@ def __init__(self, in_channels: int):
         self.norm = torch.nn.GroupNorm(
             num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
         )
-        self.q = torch.nn.Conv2d(
+        self.q = Conv2dLayer(
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
-        self.k = torch.nn.Conv2d(
+        self.k = Conv2dLayer(
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
-        self.v = torch.nn.Conv2d(
+        self.v = Conv2dLayer(
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
-        self.proj_out = torch.nn.Conv2d(
+        self.proj_out = Conv2dLayer(
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
 
@@ -681,7 +682,7 @@ def __init__(self, config: ChameleonVQVAEConfig):
         latent_channels = config.latent_channels
         channel_multiplier = config.channel_multiplier
 
-        self.conv_in = torch.nn.Conv2d(
+        self.conv_in = Conv2dLayer(
             in_channels, base_channels, kernel_size=3, stride=1, padding=1
         )
 
@@ -738,7 +739,7 @@ def __init__(self, config: ChameleonVQVAEConfig):
         self.norm_out = torch.nn.GroupNorm(
             num_groups=32, num_channels=block_in, eps=1e-6, affine=True
         )
-        self.conv_out = torch.nn.Conv2d(
+        self.conv_out = Conv2dLayer(
             block_in,
             2 * latent_channels if double_latent else latent_channels,
             kernel_size=3,
@@ -779,10 +780,8 @@ def __init__(self, config: ChameleonVQVAEConfig):
         super().__init__()
         self.encoder = ChameleonVQVAEEncoder(config)
         self.quantize = ChameleonVQVAEVectorQuantizer(config)
-        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(
-            config.embed_dim, config.latent_channels, 1
-        )
+        self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1)
+        self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1)
         self.eval()  # Chameleon's VQ model is frozen
 
     def encode(

vllm/model_executor/models/deepencoder.py

Lines changed: 8 additions & 5 deletions
@@ -19,6 +19,7 @@
 from transformers import CLIPVisionConfig
 
 from vllm.attention.layer import MultiHeadAttention
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -133,14 +134,14 @@ def __init__(
             self.blocks.append(block)
 
         self.neck = nn.Sequential(
-            nn.Conv2d(
+            Conv2dLayer(
                 embed_dim,
                 out_chans,
                 kernel_size=1,
                 bias=False,
             ),
             LayerNorm2d(out_chans),
-            nn.Conv2d(
+            Conv2dLayer(
                 out_chans,
                 out_chans,
                 kernel_size=3,
@@ -150,8 +151,10 @@ def __init__(
             LayerNorm2d(out_chans),
         )
 
-        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
-        self.net_3 = nn.Conv2d(
+        self.net_2 = Conv2dLayer(
+            256, 512, kernel_size=3, stride=2, padding=1, bias=False
+        )
+        self.net_3 = Conv2dLayer(
             512, 1024, kernel_size=3, stride=2, padding=1, bias=False
         )
 
@@ -500,7 +503,7 @@ def __init__(
         """
         super().__init__()
 
-        self.proj = nn.Conv2d(
+        self.proj = Conv2dLayer(
             in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
         )

vllm/model_executor/models/dots_ocr.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -471,7 +472,7 @@ def __init__(self, config):
         self.temporal_patch_size = config.temporal_patch_size
         self.embed_dim = config.embed_dim
         self.config = config
-        self.proj = nn.Conv2d(
+        self.proj = Conv2dLayer(
             config.num_channels,
             config.embed_dim,
             kernel_size=(config.patch_size, config.patch_size),

vllm/model_executor/models/glm4_1v.py

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@
 from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
-from vllm.model_executor.layers.conv import Conv3dLayer
+from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -734,7 +734,7 @@ def __init__(
         self.post_conv_layernorm = RMSNorm(
             vision_config.hidden_size, eps=vision_config.rms_norm_eps
         )
-        self.downsample = nn.Conv2d(
+        self.downsample = Conv2dLayer(
             in_channels=vision_config.hidden_size,
             out_channels=vision_config.out_hidden_size,
             kernel_size=vision_config.spatial_merge_size,

vllm/model_executor/models/glm4v.py

Lines changed: 3 additions & 2 deletions
@@ -24,6 +24,7 @@
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
@@ -78,7 +79,7 @@ class GLMVImagePixelInputs(TensorSchema):
 class EVA2CLIPPatchEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.proj = nn.Conv2d(
+        self.proj = Conv2dLayer(
             config.in_channels,
             config.hidden_size,
             kernel_size=config.patch_size,
@@ -333,7 +334,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.linear_proj",
         )
-        self.conv = nn.Conv2d(
+        self.conv = Conv2dLayer(
             in_channels=vision_config.hidden_size,
             out_channels=config.hidden_size,
             kernel_size=2,

vllm/model_executor/models/idefics2_vision_model.py

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@
 from vllm.attention.layer import MultiHeadAttention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -60,7 +61,7 @@ def __init__(self, config: Idefics2VisionConfig):
         self.embed_dim = config.hidden_size
         self.image_size = config.image_size
         self.patch_size = config.patch_size
-        self.patch_embedding = nn.Conv2d(
+        self.patch_embedding = Conv2dLayer(
             in_channels=config.num_channels,
             out_channels=self.embed_dim,
             kernel_size=self.patch_size,

vllm/model_executor/models/intern_vit.py

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@
     tensor_model_parallel_all_gather,
 )
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -51,7 +52,7 @@ def __init__(self, config: PretrainedConfig):
 
         self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
 
-        self.patch_embedding = nn.Conv2d(
+        self.patch_embedding = Conv2dLayer(
             in_channels=3,
             out_channels=self.embed_dim,
             kernel_size=self.patch_size,
