Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7ffec73

Browse files
committed
solve comments
1 parent 7aab3b4 commit 7ffec73

File tree

1 file changed

+21
-35
lines changed

1 file changed

+21
-35
lines changed

torchchat/model.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
from dataclasses import dataclass
1212
from enum import Enum
1313
from pathlib import Path
14-
from PIL import Image
15-
import requests
14+
1615
import torchvision
1716

1817
from typing import Any, Callable, Dict, Optional, Union
@@ -35,14 +34,10 @@
3534
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
3635
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
3736
from torchtune.modules.model_fusion import DeepFusionModel
37+
from torchtune.models.clip import clip_vision_encoder
3838

3939
from torchchat.utils.build_utils import find_multiple, get_precision
4040

41-
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
42-
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
43-
from torchtune.modules.model_fusion import DeepFusionModel
44-
from torchtune.models.clip import clip_vision_encoder
45-
4641
config_path = Path(f"{str(Path(__file__).parent)}/model_params")
4742

4843

@@ -61,19 +56,13 @@ def identity(**kwargs):
6156
return list(kwargs.values())[0]
6257

6358

64-
@dataclass
65-
class ProjectorArgs:
66-
in_channels: int = 1024
67-
out_channels: int = 4096
68-
activation: nn.Module = nn.GELU()
69-
7059

7160
class MultiModalProjector(nn.Module):
72-
def __init__(self, args: ProjectorArgs):
61+
def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
7362
super().__init__()
7463

7564
self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True)
76-
self.act = args.activation
65+
self.act = act
7766
self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True)
7867

7968
def forward(self, image_features):
@@ -105,11 +94,9 @@ def __init__(
10594
self.decoder.__setattr__(token_embedding_name, None)
10695

10796
self.mm_projector = MultiModalProjector(
108-
ProjectorArgs(
10997
in_channels=mm_proj_in_channels,
11098
out_channels=mm_proj_out_channels,
111-
activation=mm_proj_activation,
112-
)
99+
act=mm_proj_activation,
113100
)
114101

115102
def forward(
@@ -123,9 +110,7 @@ def forward(
123110
) -> Tensor:
124111
if encoder_input is not None:
125112
encoder_input = encoder_input.view(1, 1, *encoder_input.shape)
126-
encoder_output = self.encoder(
127-
encoder_input,
128-
)
113+
encoder_output = self.encoder(encoder_input)
129114
encoder_output = self._encoder_feature_select(encoder_output)
130115
else:
131116
encoder_output = None
@@ -143,10 +128,10 @@ def forward(
143128

144129
return self.decoder(decoder_input, input_pos=input_pos)
145130

146-
def setup_caches(self, batch_size, max_seq_len):
131+
def setup_caches(self, batch_size, max_seq_len) -> None:
147132
self.decoder.setup_caches(batch_size, max_seq_len)
148133

149-
def _encoder_feature_select(self, encoder_output):
134+
def _encoder_feature_select(self, encoder_output) -> Tensor:
150135
selected_image_feature = encoder_output[1][0].view(
151136
*encoder_output[1][0].shape[2:]
152137
)
@@ -160,7 +145,7 @@ def _get_decoder_input(
160145
*,
161146
encoder_output: Optional[Tensor],
162147
post_tokens: Optional[Tensor],
163-
):
148+
) -> Tensor:
164149
if encoder_output is None:
165150
assert post_tokens is None
166151
return self.tok_embeddings(tokens)
@@ -245,16 +230,17 @@ def _llava(cls):
245230

246231
@classmethod
247232
def get_recipe(cls, model_type):
248-
if model_type == ModelType.TextOnly:
249-
return cls._text_only()
250-
elif model_type == ModelType.Flamingo:
251-
return cls._flamingo()
252-
elif model_type == ModelType.Llama3_1:
253-
return cls._llama3_1()
254-
elif model_type == ModelType.Llava:
255-
return cls._llava()
256-
else:
257-
raise ValueError(f"Can not find the model recipe for {model_type}")
233+
match model_type:
234+
case ModelType.TextOnly:
235+
return cls._text_only()
236+
case ModelType.Flamingo:
237+
return cls._flamingo()
238+
case ModelType.Llama3_1:
239+
return cls._llama3_1()
240+
case ModelType.Llava:
241+
return cls._llava()
242+
case _:
243+
raise ValueError(f"Can not find the model recipe for {model_type}")
258244

259245

260246
@dataclass
@@ -475,7 +461,7 @@ def build_model(self) -> nn.Module:
475461

476462
return recipe.fusion_class(**modules)
477463

478-
def _replace_know_params(self, params):
464+
def _replace_known_params(self, params):
479465
patterns = {"QuickGELUActivation()": QuickGELUActivation()}
480466
for key, value in params.items():
481467
if isinstance(value, Hashable) and value in patterns:

0 commit comments

Comments
 (0)