This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
1 change: 1 addition & 0 deletions torchchat/generate.py
@@ -818,6 +818,7 @@ def chat(
                if text_transformer_args is not None
                else 2048
            ),
+            max_seq_length
        )

        max_seq_length = (
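Review note: the one-line addition makes the chat loop's sequence length respect the limit carried on the model itself (for example, the bound a loaded .pte model reports; see the model.py change below). A minimal sketch of the resulting logic, assuming the enclosing expression (truncated in the context lines above) is a min(...) over candidate limits; bound_seq_length is a hypothetical helper name, not a torchchat function:

# Minimal sketch (assumption: the enclosing expression is a min(...)).
def bound_seq_length(text_transformer_args, max_seq_length: int) -> int:
    # Limit from the text transformer config, with a 2048 fallback,
    # exactly as in the context lines above.
    model_limit = (
        text_transformer_args.max_seq_length
        if text_transformer_args is not None
        else 2048
    )
    # The added line also caps the result by the previously computed
    # max_seq_length (e.g. the value an ExecuTorch model advertises).
    return min(model_limit, max_seq_length)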
26 changes: 17 additions & 9 deletions torchchat/model.py
@@ -127,10 +127,10 @@ def forward(
        )

        return self.decoder(decoder_input, input_pos=input_pos)

    def setup_caches(self, batch_size, max_seq_len) -> None:
        self.decoder.setup_caches(batch_size, max_seq_len)

    def _encoder_feature_select(self, encoder_output) -> Tensor:
        selected_image_feature = encoder_output[1][0].view(
            *encoder_output[1][0].shape[2:]
@@ -154,7 +154,7 @@ def _get_decoder_input(
        image_embeds = self.mm_projector(encoder_output)
        if post_tokens is None:
            return torch.cat((pre_img_embed, image_embeds), dim=1)

        post_img_embed = self.tok_embeddings(post_tokens)
        return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)

@@ -227,7 +227,7 @@ def _llava(cls):
            },
            fusion_class=ConcateFusion,
        )

    @classmethod
    def get_recipe(cls, model_type):
        match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
    def from_params(cls, params_path):
        with open(params_path, "r") as f:
            loaded_params = json.loads(f.read())

        if (model_type_name := loaded_params.get("model_type", None)) is None:
            # The model params is in the transformer_args format
            # set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
            modules[name] = module_class(**config_args)

        return recipe.fusion_class(**modules)

    def _replace_known_params(self, params):
        patterns = {"QuickGELUActivation()": QuickGELUActivation()}
        for key, value in params.items():
            if isinstance(value, Hashable) and value in patterns:
                params[key] = patterns[value]
        return params

    @abstractmethod
    def forward(self, *args, **kwargs):
        raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
            self.model_ = exec_lib._load_for_executorch(str(path))

            self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])

+            # TODO: attempt to use the "get_max_seq_len" method on the model
+            # once the ExecuTorch bug is fixed.
+            max_seq_len = 128
+            # try:
+            #     max_seq_len = self.model_.run_method("get_max_seq_len", [])
+            # except Exception as e:
+            #     pass
+            self.text_transformer_args.max_seq_length = max_seq_len

        def forward(self, x, input_pos):
            # model_.forward expects inputs to be wrapped in a tuple
            forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):

        def setup_caches(self, max_batch_size, max_seq_length):
            pass

except:
    pass
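Review note: the hard-coded max_seq_len = 128 above is a stopgap; per the TODO, the intent is to query the exported program through its "get_max_seq_len" method once the ExecuTorch bug is fixed. A hedged sketch of that intended lookup, mirroring the commented-out run_method call in the hunk; resolve_max_seq_len and its fallback parameter are hypothetical names, not torchchat or ExecuTorch API:

# Hedged sketch: wraps the commented-out run_method call from the diff above
# with the same conservative fallback the PR hard-codes.
def resolve_max_seq_len(model_, fallback=128):
    try:
        # The call the TODO intends to restore once the ExecuTorch bug is fixed.
        return model_.run_method("get_max_seq_len", [])
    except Exception:
        # Keep the PR's conservative default if the method is unavailable.
        return fallback

The assignment in __init__ would then read self.text_transformer_args.max_seq_length = resolve_max_seq_len(self.model_).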