This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit cfeec04

Merge branch 'main' into text-only-11b

2 parents: aa5fc54 + ae3555b

2 files changed: 63 additions (+), 9 deletions (−)

README.md (1 addition, 2 deletions)

@@ -449,7 +449,7 @@ The following assumes you've completed the steps for [Setting up ExecuTorch](#se
 
 - [executorch-240919.aar](https://ossci-android.s3.amazonaws.com/executorch/main/executorch-240919.aar) (SHASUM: c8a5d38ead03bfa28ee8469f6355840ad0d182ba)
 
-2. Rename the downloaded AAR file to `executorch.aar` and move the file to `android/torchchat/app/libs/`. You may need to create directory `android/torchchat/app/libs/` if it does not exist.
+2. Rename the downloaded AAR file to `executorch.aar` and move the file to `torchchat/edge/android/torchchat/app/libs/`. You may need to create directory `torchchat/edge/android/torchchat/app/libs/` if it does not exist.
 
 3. Push the model and tokenizer file to your device. You can find the model file called `llama3.1.pte` in the current `torchchat` directory and the tokenizer file at `$(python3 torchchat.py where llama3.1)/tokenizer.model` path.
 ```

@@ -484,7 +484,6 @@ Alternatively, you can run `torchchat/utils/scripts/android_example.sh` which se
 
 ```
 export TORCHCHAT_ROOT=$(pwd)
-export USE_TIKTOKEN=ON # Set this only for tiktoken tokenizer
 sh torchchat/utils/scripts/android_example.sh
 ```
 
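As a companion to step 3 of the README excerpt above, here is a small Python sketch, not part of this commit, that resolves the two files that step mentions before they are pushed to the device; the model file name and the `torchchat.py where` invocation are taken verbatim from that step, everything else is an assumption.

# Illustrative helper (assumes it is run from the torchchat checkout that produced llama3.1.pte).
import subprocess
from pathlib import Path

# The exported ExecuTorch program lives in the current torchchat directory.
model_path = Path("llama3.1.pte").resolve()

# `python3 torchchat.py where llama3.1` prints the model directory that holds tokenizer.model.
model_dir = subprocess.run(
    ["python3", "torchchat.py", "where", "llama3.1"],
    capture_output=True,
    text=True,
    check=True,
).stdout.strip()
tokenizer_path = Path(model_dir) / "tokenizer.model"

print(f"model:     {model_path}")
print(f"tokenizer: {tokenizer_path}")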

torchchat/usages/openai_api.py (62 additions, 7 deletions)
@@ -19,7 +19,8 @@
 
 from PIL import Image
 
-from torchtune.data import Message, padded_collate
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
 from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
 
 from torchchat.cli.download import is_model_downloaded, load_model_configs
@@ -288,13 +289,50 @@ def __init__(self, *args, **kwargs):
                 else self.model.text_transformer_args.max_seq_length
             )
         except:
-            # can not find max_seq_length in model config, use default value
-            self.max_seq_length = 128
+            self.max_seq_length = 2048
+            print(f"can not find max_seq_length in model config, use default value: {self.max_seq_length}")
         # The System fingerprint is a unique identifier for the model and its configuration.
         self.system_fingerprint = (
             f"{self.builder_args.device}_{self.builder_args.precision}"
         )
 
+    def _openai_messages_to_torchtune_messages(
+        self, messages: List[_AbstractMessage]
+    ) -> List[Message]:
+        """Convert a list of OpenAI API messages to a list of TorchTune messages.
+
+        Args:
+            messages: A list of OpenAI API messages.
+
+        Returns:
+            A list of Torchtune Messages.
+        """
+        torchtune_messages = []
+        for message in messages:
+            torchtune_contents = []
+            if isinstance(message["content"], list):
+                for content_dict in message["content"]:
+                    converted_content = []
+                    if content_dict["type"] == "text":
+                        converted_content.append(
+                            {"type": "text", "content": content_dict["text"]}
+                        )
+                    elif content_dict["type"] == "image_url":
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        image = Image.open(BytesIO(base64_decoded))
+                        converted_content.append(
+                            {
+                                "type": "image",
+                                "content": image,
+                            }
+                        )
+            torchtune_messages.append(
+                Message(role=message["role"], content=converted_content, eot=False)
+            )
+        return torchtune_messages
+
     def _openai_messages_to_torchtune(
         self, messages: List[_AbstractMessage]
     ) -> List[Message]:
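To make the new helper above concrete, here is a minimal standalone sketch, not part of the commit, that applies the same conversion to one hypothetical OpenAI-style message: the text part becomes a {"type": "text", ...} entry and the base64 data URL is decoded into a PIL image, mirroring what `_openai_messages_to_torchtune_messages` does; the sample payload and variable names are illustrative only.

# Minimal sketch; a tiny in-memory PNG keeps the example self-contained.
import base64
from io import BytesIO

from PIL import Image
from torchtune.data import Message

buf = BytesIO()
Image.new("RGB", (8, 8), color="red").save(buf, format="PNG")
image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

openai_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": f"data:image/png;base64,{image_b64}"},
    ],
}

converted_content = []
for content_dict in openai_message["content"]:
    if content_dict["type"] == "text":
        converted_content.append({"type": "text", "content": content_dict["text"]})
    elif content_dict["type"] == "image_url":
        # Same rule as the diff: everything after ";base64," is the image payload.
        raw = base64.b64decode(content_dict["image_url"].split(";base64,")[1])
        converted_content.append({"type": "image", "content": Image.open(BytesIO(raw))})

torchtune_message = Message(role=openai_message["role"], content=converted_content, eot=False)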
@@ -376,15 +414,32 @@ def chunked_completion(self, completion_request: CompletionRequest):
             transform = llama3_2_vision_transform(
                 str(self.tokenizer_args.tokenizer_path)
             )
-            torchtune_messages = self._openai_messages_to_torchtune(
+            torchtune_messages = self._openai_messages_to_torchtune_messages(
                 completion_request.messages
             )
             data = transform(
                 {"images": images, "messages": torchtune_messages}, inference=True
             )
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
-            encoded = batch["tokens"]
+            seq_len = len(data["tokens"])
+            total_response_length = seq_len + completion_request.max_tokens
+            causal_mask = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )
+            input_pos = torch.arange(total_response_length)
+
+            with torch.no_grad():
+                with torch.device(self.builder_args.device):
+                    batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision)
+                    batch["causal_mask"] = causal_mask
+                    batch["input_pos"] = input_pos[None, :seq_len]
+                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+
+            #batch = padded_collate([data], self.builder_args.device)
+            encoded = batch["tokens"].view(-1)
         else:
             tokens = self.chat_formatter.encode_dialog_prompt(
                 dialog=[
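The vision path reworked above exists to serve multimodal chat completions, so a client-side sketch may help: it sends one request whose image content uses the "data:...;base64," string form that `_openai_messages_to_torchtune_messages` splits on. The endpoint URL, port, and model alias below are assumptions about a locally running torchchat server, not something this commit defines.

# Hypothetical client; adjust the URL and model name to your server setup.
import base64
import json
import urllib.request

with open("image.jpg", "rb") as f:  # any local image file
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "model": "llama3.2-11B",  # assumed model alias
    "max_tokens": 128,
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{image_b64}",
                },
            ],
        }
    ],
}

req = urllib.request.Request(
    "http://127.0.0.1:5000/v1/chat/completions",  # assumed local endpoint
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode("utf-8"))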
