diff --git a/torchchat/generate.py b/torchchat/generate.py
index 42abe664c..0a22f9160 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -20,10 +20,17 @@
 import torch._dynamo.config
 import torch._inductor.config
 
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
 from PIL import Image
 
+# torchtune model definition dependencies
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
+from torchtune.models.llama3 import llama3_tokenizer
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+from torchtune.training import set_default_dtype
+
 from torchchat.cli.builder import (
     _initialize_model,
     _initialize_tokenizer,
@@ -34,13 +41,6 @@
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
 
-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-from torchtune.training import set_default_dtype
-
 
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
@@ -357,8 +357,8 @@ def prefill(
 
             # TODO: Verify sequential prefill works with multimodal models
             is_multimodal = True
-            if 'encoder_input' in batch:
-                encoder_input = batch['encoder_input']
+            if "encoder_input" in batch:
+                encoder_input = batch["encoder_input"]
                 encoder_mask = batch["encoder_mask"]
                 is_multimodal = True
             else:
@@ -369,7 +369,13 @@ def prefill(
             seq_len = x.size(1)
             mask = batch["causal_mask"][None, :seq_len]
             input_pos = input_pos.view(1, -1)
-            logits = model(tokens=x, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+            logits = model(
+                tokens=x,
+                mask=mask,
+                encoder_input=encoder_input,
+                input_pos=input_pos,
+                encoder_mask=encoder_mask,
+            )[:, -1]
 
             if is_multimodal:
                 batch["encoder_mask"] = batch["encoder_mask"][:, -1:]
@@ -404,7 +410,9 @@ def decode_one_token(
             assert batch is not None, "Flamingo requires batch"
             mask = batch["causal_mask"][None, input_pos.item(), None, :]
             encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
-            logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
+            logits = model(
+                x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos
+            )[:, -1:]
         else:
             logits = model(x, input_pos)
             # print(f"x: {x},\n input_pos: {input_pos}\n")
@@ -492,7 +500,6 @@ def decode_n_tokens(
                 next_prob.clone() if next_prob is not None else None
             )
 
-
     def model_forward(self, model, x, input_pos):
         return model(x, input_pos)
 
@@ -605,7 +612,12 @@ def generate(
             or self.model.config.model_type == ModelType.Flamingo
         ):
             # 6404 is one-gpu affordable max_seq_length for single image input
-            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
+            model.setup_caches(
+                batch_size=1,
+                dtype=self.dtype,
+                encoder_max_seq_len=6404,
+                decoder_max_seq_len=T_new,
+            )
         else:
             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
@@ -731,9 +743,9 @@ def _gen_model_input(
         max_new_tokens: Optional[int] = None,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
-        Convert prompt and image prompts into consumable model input args. 
+        Convert prompt and image prompts into consumable model input args.
 
-        When prompt is a list, the anticipated format is OpenAI API Inspired: 
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
             [ ..., {"role": message["role"], "content": message["content"]}, ...]
 
         Args:
@@ -826,15 +838,18 @@ def _gen_model_input(
         logging.debug(encoded)
         return encoded, batch
 
-
     def chat(
         self,
         generator_args: GeneratorArgs,
     ):
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
-
-        encoded, batch = self._gen_model_input(generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens)
+
+        encoded, batch = self._gen_model_input(
+            generator_args.prompt,
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+        )
 
         model_size = sum(
             [
@@ -900,7 +915,7 @@ def chat(
                 if text_transformer_args is not None
                 else 2048
             ),
-            max_seq_length
+            max_seq_length,
         )
 
         max_seq_length = (
diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py
index 93de6e0ec..2c6237437 100644
--- a/torchchat/usages/openai_api.py
+++ b/torchchat/usages/openai_api.py
@@ -19,16 +19,16 @@
 
 from PIL import Image
 
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+
 from torchchat.cli.download import is_model_downloaded, load_model_configs
 from torchchat.generate import Generator, GeneratorArgs
 from torchchat.model import FlamingoModel
 
 from torchchat.utils.build_utils import device_sync
 
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
 
 """Dataclasses defined around the objects used the OpenAI API Chat specification.
 
@@ -291,7 +291,9 @@ def __init__(self, *args, **kwargs):
             )
         except:
             self.max_seq_length = 2048
-            print(f"can not find max_seq_length in model config, use default value: {self.max_seq_length}")
+            print(
+                f"can not find max_seq_length in model config, use default value: {self.max_seq_length}"
+            )
         # The System fingerprint is a unique identifier for the model and its configuration.
         self.system_fingerprint = (
             f"{self.builder_args.device}_{self.builder_args.precision}"