diff --git a/torchchat/generate.py b/torchchat/generate.py index 0a22f9160..8f4a0aa99 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import argparse +import base64 import itertools import logging import os @@ -12,6 +13,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from io import BytesIO from os import PathLike from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -600,9 +602,7 @@ def generate( if len(prompt.shape) > 1: prompt = prompt.squeeze(0) - T = prompt.size(0) - max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T) - T_new = T + max_new_tokens + prompt_length = prompt.size(0) # set up caches only if first inference if start_pos == 0: model = model.to(device=device) @@ -616,7 +616,7 @@ def generate( batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, - decoder_max_seq_len=T_new, + decoder_max_seq_len=max_seq_length, ) else: model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) @@ -629,7 +629,7 @@ def generate( model.reset_caches() input_pos = torch.arange( - start_pos, T + start_pos, device=device, dtype=torch.int + start_pos, prompt_length + start_pos, device=device, dtype=torch.int ) prefill_t0 = time.perf_counter() @@ -655,7 +655,7 @@ def generate( # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens(). callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2) - input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int) + input_pos = torch.tensor([start_pos + prompt_length], device=device, dtype=torch.int) accept_counts = [0] * ( speculate_k + 1 ) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long @@ -678,7 +678,7 @@ def generate( ) accept_counts[len(next_tokens) - 1] += 1 - num_added = min(T_new - input_pos - 1, len(next_tokens)) + num_added = min(max_new_tokens - input_pos - 1, len(next_tokens)) for token in next_tokens[:num_added,]: callback(token) yield token, None @@ -734,13 +734,14 @@ def _callback(self, x, *, buffer, done_generating): if len(buffer) == 4 or done_generating: print("".join(buffer), end="", flush=True) buffer.clear() - # print(, end='', flush=True) + print(, end='', flush=True) def _gen_model_input( self, prompt: Union[str | List[Any]], image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None, + max_seq_len: Optional[int] = 2048, ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]: """ Convert prompt and image prompts into consumable model input args. @@ -757,7 +758,7 @@ def _gen_model_input( Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models. """ - # Not Llama 3.2 11B + # Text-Only model if self.model.config.model_type != ModelType.Flamingo: # Single String prompt if isinstance(prompt, str): @@ -778,6 +779,7 @@ def _gen_model_input( assert ( image_prompts is None or len(image_prompts) == 1 ), "At most one image is supported at the moment" + if image_prompts and isinstance(image_prompts[0], str): images = [Image.open(image_prompts[0])] else: @@ -786,24 +788,41 @@ def _gen_model_input( assert ( max_new_tokens is not None ), "max_new_tokens must be specified for Flamingo models" - assert isinstance( - prompt, str - ), "(Currently) prompt must be a str for Flamingo models" - - is_multimodal = images is not None - content = [{"type": "text", "content": prompt}] - if is_multimodal: - content = [{"type": "image", "content": images[0]}] + content - - messages = [ - Message( - role="user", - content=content, - eot=True, - ), - Message(role="assistant", content=""), - ] + image_found = False + messages = [] + for message in prompt: + if isinstance(message["content"], str): + messages.append(Message(**message)) + + elif isinstance(message["content"], list): + images = None + for content_dict in message["content"]: + if content_dict["type"] == "text": + prompt_arg = content_dict["text"] + elif content_dict["type"] == "image_url": + assert ( + images is None + ), "At most one image is supported at the moment" + + base64_decoded = base64.b64decode( + content_dict["image_url"].split(";base64,")[1] + ) + images = [Image.open(BytesIO(base64_decoded))] + image_found = True + + is_multimodal = images is not None + content = [{"type": "text", "content": prompt_arg}] + [] + if is_multimodal: + content = [{"type": "image", "content": images[0]}] + content + + messages.append( + Message( + role="user", + content=content, + ) + ) transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path)) @@ -812,27 +831,37 @@ def _gen_model_input( with device, set_default_dtype(self.dtype): data = transform({"messages": messages}, inference=True) - if is_multimodal: + if image_found: batch = padded_collate_tiled_images_and_mask( [data], pad_direction="left", pad_max_images=1 ) encoded = batch.pop("tokens").to(device).view(-1) - seq_len = encoded.size(0) + seq_len = encoded.size(0) batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to( self.dtype ) + else: encoded = torch.tensor(data["tokens"], device=device).view(-1) seq_len = encoded.size(0) batch = {} - total_response_length = seq_len + max_new_tokens - batch["causal_mask"] = torch.tril( - torch.ones( - size=(total_response_length, total_response_length), - dtype=torch.bool, - ) + total_response_length = max_seq_len + max_new_tokens + batch["causal_mask"] = torch.nn.functional.pad( + torch.tril( + torch.ones( + size=(total_response_length, total_response_length), + dtype=torch.bool, + ) + ), + ( + 0, + max_seq_len - total_response_length, + 0, + max_seq_len - total_response_length, + ), + value=0, ) logging.debug(encoded) @@ -849,6 +878,7 @@ def chat( generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens, + generator_args.max_seq_length, ) model_size = sum( diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 2c6237437..f2d68881a 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -316,38 +316,22 @@ def _gen_model_inputs_from_openai_completion_request( if not isinstance(self.model, FlamingoModel): prompt = [ {"role": message["role"], "content": message["content"]} - for message in completion_request.messages + for message in messages ] return self._gen_model_input( prompt=prompt, max_new_tokens=completion_request.max_tokens ) # Llama 3.2 11B - prompt = None - images = None - - for message in messages: - torchtune_contents = [] - if isinstance(message["content"], list): - for content_dict in message["content"]: - if content_dict["type"] == "text": - assert ( - prompt is None - ), "At most one text prompt is supported for each request" - prompt = content_dict["text"] - elif content_dict["type"] == "image_url": - assert ( - images is None - ), "At most one image is supported at the moment" - - base64_decoded = base64.b64decode( - content_dict["image_url"].split(";base64,")[1] - ) - images = [Image.open(BytesIO(base64_decoded))] - - assert prompt is not None, "Text prompt must be specified in the request" - - return self._gen_model_input(prompt, images, completion_request.max_tokens) + + prompt = [ + {"role": message["role"], "content": message["content"]} + for message in messages + ] + + return self._gen_model_input( + prompt=prompt, max_new_tokens=completion_request.max_tokens + ) def chunked_completion(self, completion_request: CompletionRequest): """Handle a chat completion request and yield a chunked response.