Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 728fc46

Browse files
committed
reformat llava
1 parent 353fafe commit 728fc46

File tree

1 file changed

+40
-18
lines changed

torchchat/model.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
1-
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# All rights reserved.
3-
4-
# This source code is licensed under the license found in the
5-
# LICENSE file in the root directory of this source tree.
61
import json
72
import os
83
import warnings
4+
from abc import ABC, abstractmethod
95

106
from dataclasses import dataclass
117
from enum import Enum
128
from pathlib import Path
139

1410
from typing import Any, Callable, Dict, Optional, Union
15-
from abc import ABC, abstractmethod
1611

1712
import torch
1813
import torch.nn as nn
1914

2015
from torch import Tensor
21-
from torch.distributed._tensor import Replicate, Shard, DTensor
16+
from torch.distributed._tensor import DTensor, Replicate, Shard
2217
from torch.distributed.device_mesh import DeviceMesh
2318
from torch.distributed.tensor.parallel import (
2419
ColwiseParallel,
@@ -31,11 +26,12 @@
3126
from torchchat.utils.build_utils import find_multiple, get_precision
3227

3328
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
34-
from torchtune.modules.model_fusion import DeepFusionModel
3529
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
30+
from torchtune.modules.model_fusion import DeepFusionModel
3631

3732
config_path = Path(f"{str(Path(__file__).parent)}/model_params")
3833

34+
3935
def identity(**kwargs):
4036
if len(kwargs) != 1:
4137
raise ValueError("Only one argument is expected")
@@ -56,8 +52,17 @@ def forward(self, image_features):
5652
hidden_states = self.linear_2(hidden_states)
5753
return hidden_states
5854

55+
5956
class ConcateFusion(nn.Module):
60-
def __init__(self, encoder: nn.Module, decoder: nn.Module, token_embedding_name="tok_embeddings", mm_proj_in_channels=1024, mm_proj_out_channels=4096, mm_proj_activation=nn.GELU):
57+
def __init__(
58+
self,
59+
encoder: nn.Module,
60+
decoder: nn.Module,
61+
token_embedding_name="tok_embeddings",
62+
mm_proj_in_channels=1024,
63+
mm_proj_out_channels=4096,
64+
mm_proj_activation=nn.GELU,
65+
):
6166
super().__init__()
6267
self.encoder = encoder
6368
self.decoder = decoder
@@ -67,36 +72,53 @@ def __init__(self, encoder: nn.Module, decoder: nn.Module, token_embedding_name=
6772
self.tok_embeddings = getattr(self.decoder, token_embedding_name)
6873

6974
# set the embedding layer in decoder to None to jump the embedding layer over in decoder
70-
self.decoder.__setattr__(token_embedding_name) = None
75+
self.decoder.__setattr__(token_embedding_name, None)
7176

72-
self.mm_projector = MultiModalProjector(ProjectorArgs(in_channels=mm_proj_in_channels, out_channels=mm_proj_out_channels, activation=mm_proj_activation))
77+
self.mm_projector = MultiModalProjector(
78+
ProjectorArgs(
79+
in_channels=mm_proj_in_channels,
80+
out_channels=mm_proj_out_channels,
81+
activation=mm_proj_activation,
82+
)
83+
)
7384

# Method of ConcateFusion (llava-style concatenation fusion of a vision
# encoder and a text decoder); shown here at top level for review.
def forward(
    self,
    tokens: Tensor,
    *,
    post_tokens: Optional[Tensor] = None,
    encoder_input: Optional[Tensor] = None,
    encoder_mask: Optional[torch.Tensor] = None,
    input_pos: Optional[torch.Tensor] = None,
) -> Tensor:
    """Run the fused encoder/decoder over a (possibly multimodal) prompt.

    Args:
        tokens: token ids preceding the image segment.
        post_tokens: token ids following the image segment; paired with
            ``encoder_input`` (both given or both ``None``).
        encoder_input: raw input for the vision encoder, or ``None`` for a
            text-only forward pass.
        encoder_mask: accepted for interface compatibility; not used by the
            visible code path.
        input_pos: accepted for interface compatibility; not used by the
            visible code path.

    Returns:
        The decoder output for the assembled input sequence.
    """
    # BUG FIX: the original tested `if encoder_input:`, which calls
    # Tensor.__bool__ and raises "Boolean value of Tensor with more than
    # one element is ambiguous" for any real image batch. The intended
    # test is presence, i.e. an explicit None check.
    if encoder_input is not None:
        encoder_output = self.encoder(encoder_input)
    else:
        encoder_output = None

    # NOTE(review): encoder_output is computed but not consumed below —
    # confirm the intended wiring of the encoder result.
    decoder_input = self._get_decoder_input(
        tokens, encoder_input=encoder_input, post_tokens=post_tokens
    )
    return self.decoder(decoder_input)
90105

91-
def _get_decoder_input(self, tokens: Tensor, *, encoder_input: Optional[Tensor], post_tokens: Optional[Tensor]):
92-
assert bool(encoder_input) == bool(post_tokens), "encoder_input and post_tokens must be both None or not None"
106+
def _get_decoder_input(
107+
self,
108+
tokens: Tensor,
109+
*,
110+
encoder_input: Optional[Tensor],
111+
post_tokens: Optional[Tensor],
112+
):
113+
assert bool(encoder_input) == bool(
114+
post_tokens
115+
), "encoder_input and post_tokens must be both None or not None"
93116
if encoder_input is None:
94117
return self.tok_embeddings(tokens)
95118
else:
96119
pre_img_embed = self.tok_embeddings(tokens)
97120
post_img_embed = self.tok_embeddings(post_tokens)
98121
return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
99-
100122

101123

102124

0 commit comments

Comments (0)