
Commit 128566c

prepare for rebase
1 parent 83f8501 commit 128566c

File tree

1 file changed: +8 −95 lines changed


torchchat/model.py

Lines changed: 8 additions & 95 deletions
@@ -128,6 +128,14 @@ def forward(
         decoder_input = self._get_decoder_input(
             tokens, encoder_output=encoder_output, post_tokens=post_tokens
         )
+
+        if input_pos is None:
+            input_pos = torch.arange(
+                decoder_input.shape[1],
+                device=decoder_input.device,
+                dtype=torch.int,
+            )
+
         return self.decoder(decoder_input, input_pos=input_pos)
 
     def setup_caches(self, batch_size, max_seq_len):
@@ -977,98 +985,3 @@ def setup_caches(self, max_batch_size, max_seq_length):
 
 except:
     pass
-
-
-if __name__ == "__main__":
-    def prepare_image(target_h: int, target_w: int) -> torch.Tensor:
-        """Read image into a tensor and resize the image so that it fits in
-        a target_h x target_w canvas.
-
-        Args:
-            image (Image): An Image object.
-            target_h (int): Target height.
-            target_w (int): Target width.
-
-        Returns:
-            torch.Tensor: resized image tensor.
-        """
-        image = Image.open(
-            requests.get(
-                "https://llava-vl.github.io/static/images/view.jpg", stream=True
-            ).raw)
-
-        img = torchvision.transforms.functional.pil_to_tensor(image)
-        # height ratio
-        ratio_h = img.shape[1] / target_h
-        # width ratio
-        ratio_w = img.shape[2] / target_w
-        # resize the image so that it fits in a target_h x target_w canvas
-        ratio = max(ratio_h, ratio_w)
-        output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
-        img = torchvision.transforms.Resize(size=output_size)(img)
-        return img
-
-
-    def image_preprocess(img: torch.Tensor, target_h: int, target_w: int, rescale_factor, image_mean, image_std) -> torch.Tensor:
-        # pad the image with median rgb value, to make a square
-        l_pad = (target_w - img.shape[2]) // 2
-        t_pad = (target_h - img.shape[1]) // 2
-        # ceil division
-        r_pad = -((target_w - img.shape[2]) // -2)
-        b_pad = -((target_h - img.shape[1]) // -2)
-
-        torch._check(l_pad >= 0)
-        torch._check(t_pad >= 0)
-        torch._check(r_pad >= 0)
-        torch._check(b_pad >= 0)
-
-        # This is different from the original implementation, due to export limitations.
-        resized = torch.nn.functional.pad(
-            img,
-            (l_pad, r_pad, t_pad, b_pad),
-        )
-        # originally:
-        # resized = F.pad(
-        #     img,
-        #     padding=(l_pad, t_pad, r_pad, b_pad),
-        #     fill=tuple(int(x * 255) for x in self.image_mean),
-        # )
-
-        # TODO: implement _upsample_bicubic_aa.out in portable kernel library.
-        # here padded shape should be max(h, w) x max(h, w)
-        # skipping resize for now due to missing _upsample_bicubic_aa kernel in portable
-        # resized = resize(
-        #     padded,
-        #     size=[
-        #         self.image_processor.crop_size["height"],
-        #         self.image_processor.crop_size["width"],
-        #     ],
-        #     interpolation="bicubic",
-        # )
-        # torch._check(resized.size(1) == self.config.crop_size["height"])
-        # torch._check(resized.size(2) == self.config.crop_size["width"])
-        # print(resized.shape)
-        # cropped = F.center_crop(img, output_size=[w, w])
-        # print(cropped.shape)
-        scaled = resized * rescale_factor
-        # print(scaled)
-        from torchvision.transforms.v2 import functional as tvF
-        normed = tvF.normalize(
-            scaled, image_mean, image_std
-        )
-        # print(normed)
-        return normed.unsqueeze(0)
-
-    pre_tokens = torch.tensor([[    1,   319, 13563,  1546,   263, 12758,  5199,   322,   385, 23116,
-                                21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
-                                  322,  1248,   568,  6089,   304,   278,  5199, 29915, 29879,  5155,
-                                29889,  3148,  1001, 29901, 29871]])
-    img = prepare_image(336, 336)
-    post_tokens = torch.tensor([[29871,    13,   462,  9651,  1724,   526,   278,  2712,   306,   881,
-                                   367,   274,  1300,  2738,  1048,   746,   306,  6493,  1244, 29973,
-                                   319,  1799,  9047, 13566, 29901]])
-
-    llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json")
-    llava_model.setup_caches(1, 2048)
-    img = image_preprocess(img=img, target_h=336, target_w=336, image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711], rescale_factor=0.00392156862745098)
-    llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens)
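
For context on the first hunk: forward previously relied on the caller to pass input_pos; with this change, a missing input_pos defaults to sequential positions spanning the whole decoder input, so a plain prefill call works without precomputing positions. Below is a minimal standalone sketch of that defaulting logic; the default_positions helper and the example shapes are hypothetical illustrations, not torchchat APIs.

from typing import Optional

import torch


def default_positions(
    decoder_input: torch.Tensor, input_pos: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Mirrors the added hunk: when no positions are supplied, treat the call as
    # a prefill over the whole sequence and build positions 0..seq_len-1 on the
    # same device as the decoder input (assumed layout: batch x seq_len x dim).
    if input_pos is None:
        input_pos = torch.arange(
            decoder_input.shape[1],
            device=decoder_input.device,
            dtype=torch.int,
        )
    return input_pos


x = torch.zeros(1, 35, 4096)                     # e.g. 35 prompt tokens after embedding
print(default_positions(x))                      # tensor([0, 1, ..., 34], dtype=torch.int32)
print(default_positions(x, torch.tensor([35])))  # an explicit decode-step position passes through unchanged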
