|
@@ -22,6 +22,7 @@
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -549,7 +550,7 @@ def forward(self, hidden_state: torch.Tensor):
 class ChameleonVQVAEEncoderConvDownsample(nn.Module):
     def __init__(self, in_channels: int):
         super().__init__()
-        self.conv = nn.Conv2d(
+        self.conv = Conv2dLayer(
             in_channels, in_channels, kernel_size=3, stride=2, padding=0
         )

@@ -577,23 +578,23 @@ def __init__(
         self.norm1 = torch.nn.GroupNorm(
             num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
         )
-        self.conv1 = torch.nn.Conv2d(
+        self.conv1 = Conv2dLayer(
             in_channels, out_channels, kernel_size=3, stride=1, padding=1
         )
         self.norm2 = torch.nn.GroupNorm(
             num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
         )
         self.dropout = torch.nn.Dropout(config.dropout)
-        self.conv2 = torch.nn.Conv2d(
+        self.conv2 = Conv2dLayer(
             out_channels, out_channels, kernel_size=3, stride=1, padding=1
         )
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(
+                self.conv_shortcut = Conv2dLayer(
                     in_channels, out_channels, kernel_size=3, stride=1, padding=1
                 )
             else:
-                self.nin_shortcut = torch.nn.Conv2d(
+                self.nin_shortcut = Conv2dLayer(
                     in_channels, out_channels, kernel_size=1, stride=1, padding=0
                 )

@@ -626,16 +627,16 @@ def __init__(self, in_channels: int):
         self.norm = torch.nn.GroupNorm(
             num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
         )
-        self.q = torch.nn.Conv2d(
+        self.q = Conv2dLayer(
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
-        self.k = torch.nn.Conv2d(
+        self.k = Conv2dLayer(
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
-        self.v = torch.nn.Conv2d(
+        self.v = Conv2dLayer(
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
-        self.proj_out = torch.nn.Conv2d(
+        self.proj_out = Conv2dLayer(
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )

@@ -681,7 +682,7 @@ def __init__(self, config: ChameleonVQVAEConfig):
         latent_channels = config.latent_channels
         channel_multiplier = config.channel_multiplier

-        self.conv_in = torch.nn.Conv2d(
+        self.conv_in = Conv2dLayer(
             in_channels, base_channels, kernel_size=3, stride=1, padding=1
         )

@@ -738,7 +739,7 @@ def __init__(self, config: ChameleonVQVAEConfig):
         self.norm_out = torch.nn.GroupNorm(
             num_groups=32, num_channels=block_in, eps=1e-6, affine=True
         )
-        self.conv_out = torch.nn.Conv2d(
+        self.conv_out = Conv2dLayer(
             block_in,
             2 * latent_channels if double_latent else latent_channels,
             kernel_size=3,
@@ -779,10 +780,8 @@ def __init__(self, config: ChameleonVQVAEConfig):
         super().__init__()
         self.encoder = ChameleonVQVAEEncoder(config)
         self.quantize = ChameleonVQVAEVectorQuantizer(config)
-        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(
-            config.embed_dim, config.latent_channels, 1
-        )
+        self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1)
+        self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1)
         self.eval()  # Chameleon's VQ model is frozen

     def encode(
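
Every call site in this diff keeps its nn.Conv2d constructor arguments unchanged, so the swap relies on Conv2dLayer accepting the same signature. Below is a minimal sketch of such a drop-in wrapper, assuming the real layer in vllm/model_executor/layers/conv.py mirrors the nn.Conv2d interface; the class body here is a hypothetical stand-in, not the vLLM implementation.

import torch
import torch.nn as nn


class Conv2dLayer(nn.Conv2d):
    """Hypothetical drop-in wrapper around nn.Conv2d.

    Assumption: the vLLM layer preserves the nn.Conv2d constructor
    signature, so the class name can be swapped at each call site
    without touching the arguments.
    """


if __name__ == "__main__":
    # Mirrors the ChameleonVQVAEEncoderConvDownsample call site above.
    conv = Conv2dLayer(64, 64, kernel_size=3, stride=2, padding=0)
    x = torch.randn(1, 64, 32, 32)
    print(conv(x).shape)  # torch.Size([1, 64, 15, 15])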
|