diff --git a/.gitignore b/.gitignore index 3f25b76c0..044bad856 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,6 @@ system_info.txt # build artifacts checkpoints/ exportedModels/ + +# test script +_torchchat_test_script.py diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 47fd5b36d..d7183ee30 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -52,10 +52,6 @@ PYTORCH_NIGHTLY_VERSION=dev20240901 # Nightly version for torchvision VISION_NIGHTLY_VERSION=dev20240901 -# Nightly version for torchtune -TUNE_NIGHTLY_VERSION=dev20240916 - - # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same ( set -x @@ -76,7 +72,6 @@ fi REQUIREMENTS_TO_INSTALL=( torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}" torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" - torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}" ) # Install the requirements. --extra-index-url tells pip to look for package @@ -92,6 +87,12 @@ REQUIREMENTS_TO_INSTALL=( $PIP_EXECUTABLE install torchao=="0.5.0" ) +# Rely on the latest tochtune for flamingo support +( + set -x + $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git@18efc81dda1c537bb7c25058ff059b4623ccff58 +) + if [[ -x "$(command -v nvidia-smi)" ]]; then ( set -x diff --git a/torchchat/generate.py b/torchchat/generate.py index 14c4832e3..6b7dc1432 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -21,7 +21,7 @@ import torch._inductor.config try: - from _torchchat_test_script import flamingo_transform, padded_collate + from _torchchat_test_script import flamingo_transform except ImportError: pass @@ -38,8 +38,9 @@ from torchchat.utils.device_info import get_device_info # torchtune model definition dependencies -from torchtune.data import Message -from torchtune.generation._generation import sample as tune_sample +from torchtune.data import Message, padded_collate_tiled_images_and_mask + +from torchtune.generation import sample as tune_sample from torchtune.models.llama3 import llama3_tokenizer from torchtune.training import set_default_dtype @@ -357,15 +358,25 @@ def prefill( if batch is not None: # TODO: Verify sequential prefill works with multimodal models - logits = model(**batch)[:, -1] - return tune_sample(logits, 0, 500) + tokens = batch["tokens"] + if 'encoder_input' in batch: + encoder_input = batch['encoder_input'] + else: + encoder_input = None + + seq_len = tokens.size(1) + mask = batch["causal_mask"][None, :seq_len] + encoder_mask = batch["encoder_mask"] + input_pos = input_pos.view(1, -1) + logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1] + return tune_sample(logits, temperature=0, top_k=500) elif sequential_prefill: for i in range(width): x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}") logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i]) elif self.model.config.model_type == ModelType.Flamingo: - logits = model(x) + assert False, "Flamingo requires batch" else: # input_pos: [B, S] logits = model(x, input_pos) @@ -387,10 +398,10 @@ def decode_one_token( assert input_pos.shape[-1] == 1 x = x.view(1, -1) if model.config.model_type == ModelType.Flamingo: - if batch is not None: - logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:]) - else: - logits = model(x) + assert batch is not None, "Flamingo requires batch" + mask = batch["causal_mask"][None, input_pos.item(), None, :] + encoder_mask = batch["encoder_mask"][:, -1:] + logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:] else: logits = model(x, input_pos) # print(f"x: {x},\n input_pos: {input_pos}\n") @@ -593,7 +604,8 @@ def generate( self.is_torchtune_model or self.model.config.model_type == ModelType.Flamingo ): - model.setup_caches(max_batch_size=1, dtype=self.dtype) + # 6404 is one-gpu affordable max_seq_length for single image input + model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new) else: model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) if is_speculative and draft_model is not model: @@ -742,10 +754,22 @@ def chat( ] transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path)) - data = transform({"messages": messages}, inference=True) - batch = padded_collate([data], self.builder_args.device) - batch.pop("mask") - encoded = batch["tokens"] + + with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype): + data = transform({"messages": messages}, inference=True) + batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1) + # set_default_dtype can not handle the dtype of the image tensor inside the batch; need to manually cast it + batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype) + seq_len = len(data["tokens"]) + total_response_length = seq_len + generator_args.max_new_tokens + batch["causal_mask"] = torch.tril( + torch.ones( + size=(total_response_length, total_response_length), + dtype=torch.bool, + ) + ) + batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] + encoded = batch["tokens"].view(-1) else: encoded = self.encode_tokens( diff --git a/torchchat/model.py b/torchchat/model.py index e6616dc9d..edb0ce3d5 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -458,6 +458,13 @@ def build_model(self) -> nn.Module: modules[name] = module_class(TransformerArgs.from_params(config_args)) else: modules[name] = module_class(**config_args) + + # Temporary add extra params to the DeepFusionModel. + # TODO: Remove it once we can make fusion model configurable in model_param. + if recipe.fusion_class == DeepFusionModel: + modules["encoder_trainable"] = False + modules["decoder_trainable"] = False + modules["fusion_trainable"] = False return recipe.fusion_class(**modules) @@ -535,18 +542,28 @@ def reset_caches(self): class FlamingoModel(Model): def forward( self, - tokens: Tensor, - encoder_input: Optional[Dict[str, Tensor]] = None, - encoder_mask: Optional[Tensor] = None, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[Dict] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, ) -> Tensor: - if encoder_input is None: - return self.model(tokens, encoder_mask=encoder_mask) return self.model( - tokens, encoder_input=encoder_input, encoder_mask=encoder_mask + tokens, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, ) - def setup_caches(self, max_batch_size, dtype): - self.model.setup_caches(max_batch_size, dtype=dtype) + def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): + self.model.setup_caches( + batch_size=batch_size, + dtype=dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, + ) def reset_caches(self): self.model.reset_caches()