diff --git a/torchchat/generate.py b/torchchat/generate.py
index 0a22f9160..a8501328e 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
@@ -12,6 +13,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,8 @@ def generate(
 
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
@@ -616,7 +617,7 @@
                         batch_size=1,
                         dtype=self.dtype,
                         encoder_max_seq_len=6404,
-                        decoder_max_seq_len=T_new,
+                        decoder_max_seq_len=max_seq_length,
                     )
                 else:
                     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +630,7 @@
                 model.reset_caches()
 
         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )
 
         prefill_t0 = time.perf_counter()
@@ -655,7 +656,9 @@
        # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
        callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +681,7 @@
                 )
 
                 accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                 for token in next_tokens[:num_added,]:
                     callback(token)
                     yield token, None
@@ -741,6 +744,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +761,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
""" - # Not Llama 3.2 11B + # Text-Only model if self.model.config.model_type != ModelType.Flamingo: # Single String prompt if isinstance(prompt, str): @@ -778,32 +782,69 @@ def _gen_model_input( assert ( image_prompts is None or len(image_prompts) == 1 ), "At most one image is supported at the moment" + if image_prompts and isinstance(image_prompts[0], str): images = [Image.open(image_prompts[0])] else: - images = image_prompts + images = None assert ( max_new_tokens is not None ), "max_new_tokens must be specified for Flamingo models" - assert isinstance( - prompt, str - ), "(Currently) prompt must be a str for Flamingo models" - is_multimodal = images is not None - content = [{"type": "text", "content": prompt}] + image_found = False + messages = [] + for message in prompt: + if isinstance(message["content"], str): + if not image_found and image_prompts: + messages.append( + Message( + role=message["role"], + content=[ + {"type": "image", "content": images[0]}, + {"type": "text", "content": message["content"]}, + ], + ) + ) + image_found = True + else: + messages.append(Message(**message)) + + elif isinstance(message["content"], list): + images = None + for content_dict in message["content"]: + if content_dict["type"] == "text": + prompt_arg = content_dict["text"] + elif content_dict["type"] == "image_url": + assert ( + images is None + ), "At most one image is supported at the moment" + + base64_decoded = base64.b64decode( + content_dict["image_url"].split(";base64,")[1] + ) + images = [Image.open(BytesIO(base64_decoded))] + image_found = True + + is_multimodal = images is not None + content = [{"type": "text", "content": prompt_arg}] + + if is_multimodal: + content = [{"type": "image", "content": images[0]}] + content - if is_multimodal: - content = [{"type": "image", "content": images[0]}] + content + messages.append( + Message( + role=message["role"], + content=content, + ) + ) - messages = [ + messages.append( Message( - role="user", - content=content, - eot=True, - ), - Message(role="assistant", content=""), - ] + role="assistant", + content="", + ) + ) transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path)) @@ -812,7 +853,7 @@ def _gen_model_input( with device, set_default_dtype(self.dtype): data = transform({"messages": messages}, inference=True) - if is_multimodal: + if image_found: batch = padded_collate_tiled_images_and_mask( [data], pad_direction="left", pad_max_images=1 ) @@ -822,17 +863,27 @@ def _gen_model_input( batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to( self.dtype ) + else: encoded = torch.tensor(data["tokens"], device=device).view(-1) seq_len = encoded.size(0) batch = {} total_response_length = seq_len + max_new_tokens - batch["causal_mask"] = torch.tril( - torch.ones( - size=(total_response_length, total_response_length), - dtype=torch.bool, - ) + batch["causal_mask"] = torch.nn.functional.pad( + torch.tril( + torch.ones( + size=(total_response_length, total_response_length), + dtype=torch.bool, + ) + ), + ( + 0, + max_seq_len - total_response_length, + 0, + max_seq_len - total_response_length, + ), + value=0, ) logging.debug(encoded) @@ -845,12 +896,6 @@ def chat( if generator_args.chat_mode: print("Starting Interactive Chat") - encoded, batch = self._gen_model_input( - generator_args.prompt, - generator_args.image_prompts, - generator_args.max_new_tokens, - ) - model_size = sum( [ p.numel() * p.dtype.itemsize @@ -896,6 +941,12 @@ def chat( max_seq_length = ( text_transformer_args.max_seq_length if 
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(
@@ -907,16 +958,16 @@ def chat(
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )
 
         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
diff --git a/torchchat/usages/browser.py b/torchchat/usages/browser.py
index 6bd760ccc..3a2eea4ac 100644
--- a/torchchat/usages/browser.py
+++ b/torchchat/usages/browser.py
@@ -10,30 +10,39 @@
 
 from openai import OpenAI
 
+st.set_page_config(page_title="torchchat", page_icon="🤖")
 st.title("torchchat")
 
+
 start_state = [
     {
         "role": "system",
-        "content": "You're a helpful assistant - have fun.",
+        "content": "You're a helpful assistant - be brief and have fun.",
    },
    {"role": "assistant", "content": "How can I help you?"},
 ]
 
-st.session_state.uploader_key = 0
+
+def reset_chat():
+    st.session_state["messages"] = start_state
+    st.session_state["conversation_images"] = []
 
 
-def reset_per_message_state():
-    # Catch all function for anything that should be reset between each message.
-    _update_uploader_key()
+if "messages" not in st.session_state:
+    st.session_state.messages = start_state
+if "conversation_images" not in st.session_state:
+    st.session_state.conversation_images = []
 
 
-def _update_uploader_key():
-    # Increment the uploader key to reset the file uploader after each message.
-    st.session_state.uploader_key = int(time.time())
+def _upload_image_prompts(file_uploads):
+    for file in file_uploads:
+        st.session_state.conversation_images.append(file)
 
 
 with st.sidebar:
+    if st.button("Reset Chat", type="primary"):
+        reset_chat()
+
     # API Configuration
     api_base_url = st.text_input(
         label="API Base URL",
@@ -41,6 +50,11 @@ def _update_uploader_key():
         help="The base URL for the OpenAI API to connect to",
     )
 
+    client = OpenAI(
+        base_url=api_base_url,
+        api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
+    )
+
     st.divider()
     temperature = st.slider(
         "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
@@ -49,28 +63,6 @@ def _update_uploader_key():
     response_max_tokens = st.slider(
         "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
     )
-    if st.button("Reset Chat", type="primary"):
-        st.session_state["messages"] = start_state
-
-    image_prompts = st.file_uploader(
-        "Image Prompts",
-        type=["jpeg"],
-        accept_multiple_files=True,
-        key=st.session_state.uploader_key,
-    )
-
-    for image in image_prompts:
-        st.image(image)
-
-
-client = OpenAI(
-    base_url=api_base_url,
-    api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
-)
-
-if "messages" not in st.session_state:
-    st.session_state["messages"] = start_state
-
 
 for msg in st.session_state.messages:
     with st.chat_message(msg["role"]):
@@ -86,6 +78,7 @@ def _update_uploader_key():
                    st.write(content["text"])
        elif type(msg["content"]) is dict:
            if msg["content"]["type"] == "image_url":
+                pass
                st.image(msg["content"]["image_url"])
            else:
                st.write(msg["content"]["text"])
@@ -98,8 +91,8 @@ def _update_uploader_key():
 if prompt := st.chat_input():
     user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
 
-    if image_prompts:
-        for image_prompt in image_prompts:
+    if len(st.session_state.conversation_images) > 0:
+        for image_prompt in st.session_state.conversation_images:
             extension = Path(image_prompt.name).suffix.strip(".")
             image_bytes = image_prompt.getvalue()
             base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
@@ -113,11 +106,9 @@ def _update_uploader_key():
     with st.chat_message("user"):
         st.write(prompt)
 
-        for img in image_prompts:
+        for img in st.session_state.conversation_images:
             st.image(img)
-
-    image_prompts = None
-    reset_per_message_state()
+    st.session_state.conversation_images = []
 
     with st.chat_message("assistant"), st.status(
         "Generating... ", expanded=True
@@ -154,3 +145,21 @@ def get_streamed_completion(completion_generator):
             print(e)
 
     st.session_state.messages.append({"role": "assistant", "content": response})
+
+# Note: This section needs to be at the end of the file to ensure that the session state is updated before the sidebar is rendered.
+with st.sidebar:
+    st.divider()
+
+    with st.form("image_uploader", clear_on_submit=True):
+        file_uploads = st.file_uploader(
+            "Upload Image Prompts",
+            type=["jpeg"],
+            accept_multiple_files=True,
+        )
+        submitted = st.form_submit_button("Attach images to chat message")
+        if submitted:
+            _upload_image_prompts(file_uploads)
+
+    st.markdown("Image Prompts")
+    for image in st.session_state.conversation_images:
+        st.image(image)
diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py
index 2c6237437..f2d68881a 100644
--- a/torchchat/usages/openai_api.py
+++ b/torchchat/usages/openai_api.py
@@ -316,38 +316,22 @@ def _gen_model_inputs_from_openai_completion_request(
         if not isinstance(self.model, FlamingoModel):
             prompt = [
                 {"role": message["role"], "content": message["content"]}
-                for message in completion_request.messages
+                for message in messages
             ]
             return self._gen_model_input(
                 prompt=prompt, max_new_tokens=completion_request.max_tokens
             )
 
         # Llama 3.2 11B
-        prompt = None
-        images = None
-
-        for message in messages:
-            torchtune_contents = []
-            if isinstance(message["content"], list):
-                for content_dict in message["content"]:
-                    if content_dict["type"] == "text":
-                        assert (
-                            prompt is None
-                        ), "At most one text prompt is supported for each request"
-                        prompt = content_dict["text"]
-                    elif content_dict["type"] == "image_url":
-                        assert (
-                            images is None
-                        ), "At most one image is supported at the moment"
-
-                        base64_decoded = base64.b64decode(
-                            content_dict["image_url"].split(";base64,")[1]
-                        )
-                        images = [Image.open(BytesIO(base64_decoded))]
-
-        assert prompt is not None, "Text prompt must be specified in the request"
-
-        return self._gen_model_input(prompt, images, completion_request.max_tokens)
+
+        prompt = [
+            {"role": message["role"], "content": message["content"]}
+            for message in messages
+        ]
+
+        return self._gen_model_input(
+            prompt=prompt, max_new_tokens=completion_request.max_tokens
+        )
 
     def chunked_completion(self, completion_request: CompletionRequest):
"""Handle a chat completion request and yield a chunked response.