This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@Jack-Khuu (Contributor) commented on Oct 1, 2024

Refactor the OpenAI parsing logic for non-multimodal (non-MM) dialogs into _gen_model_input of generate.py.
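For context, a hypothetical sketch of the kind of OpenAI-style dialog flattening this PR consolidates into _gen_model_input; the function name and prompt format below are illustrative only and are not taken from generate.py:

import json
from typing import Dict, List

def flatten_dialog(messages: List[Dict[str, str]]) -> str:
    # Join OpenAI-style {role, content} messages into one prompt string.
    # The real implementation uses the model's chat template and may differ.
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

print(flatten_dialog([
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]))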


Tested via browser:

python torchchat.py server llama3.2-11B
python torchchat.py server llama3.2-1B

Then tested in the browser with the 1B model (text only; multiturn works) and the 11B model (multimodal; no multiturn, as expected):

streamlit run torchchat/usages/browser.py
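As an additional smoke test, one could hit the server's chat endpoint directly from Python. A minimal sketch, assuming the server listens on 127.0.0.1:5000 and exposes an OpenAI-compatible /v1/chat/completions route (both the port and the route are assumptions here, not taken from this PR):

import json
import urllib.request

payload = {
    "model": "llama3.2-1b",
    "messages": [{"role": "user", "content": "Hello!"}],
}
req = urllib.request.Request(
    "http://127.0.0.1:5000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    # Print the decoded JSON response from the server.
    print(json.loads(resp.read().decode("utf-8")))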

@pytorch-bot bot commented on Oct 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1248

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 29f5204 with merge base edaa15c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Oct 1, 2024
Comment on lines +783 to +832
assert (
    max_new_tokens is not None
), "max_new_tokens must be specified for Flamingo models"
assert isinstance(
    prompt, str
), "(Currently) prompt must be a str for Flamingo models"

is_multimodal = images is not None
content = [{"type": "text", "content": prompt}]

if is_multimodal:
    content = [{"type": "image", "content": images[0]}] + content

messages = [
    Message(
        role="user",
        content=content,
        eot=True,
    ),
    Message(role="assistant", content=""),
]

transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

device = torch.device(device=self.builder_args.device)

with device, set_default_dtype(self.dtype):
    data = transform({"messages": messages}, inference=True)

if is_multimodal:
    batch = padded_collate_tiled_images_and_mask(
        [data], pad_direction="left", pad_max_images=1
    )
    encoded = batch.pop("tokens").to(device).view(-1)
    seq_len = encoded.size(0)
    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
        self.dtype
    )
else:
    encoded = torch.tensor(data["tokens"], device=device).view(-1)
    seq_len = encoded.size(0)
    batch = {}

total_response_length = seq_len + max_new_tokens
batch["causal_mask"] = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)
@Jack-Khuu (Contributor, Author) commented:

Lint and white space
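For context on the mask construction in the diff above: causal_mask is a plain lower-triangular boolean matrix sized for the full prompt-plus-generation window. A standalone sketch with made-up sizes:

import torch

seq_len, max_new_tokens = 4, 2  # illustrative sizes only
total_response_length = seq_len + max_new_tokens

# Lower-triangular boolean mask: position i may attend to positions <= i.
causal_mask = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)
print(causal_mask)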

@Jack-Khuu requested review from byjlw and vmpuri on October 1, 2024
@Jack-Khuu marked this pull request as ready for review on October 1, 2024
@Jack-Khuu merged commit 58185b6 into main on Oct 2, 2024
52 checks passed
@Jack-Khuu deleted the reunify-request-generation branch on October 5, 2024