From d11f0e4197449e66404777cb7c5d53d75a3dbab7 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Tue, 24 Sep 2024 13:02:40 -0700
Subject: [PATCH 1/8] update flamingo model for tune

---
 torchchat/model.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/torchchat/model.py b/torchchat/model.py
index e6616dc9d..98f37313f 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -535,18 +535,28 @@ def reset_caches(self):
 class FlamingoModel(Model):
     def forward(
         self,
-        tokens: Tensor,
-        encoder_input: Optional[Dict[str, Tensor]] = None,
-        encoder_mask: Optional[Tensor] = None,
+        tokens: torch.Tensor,
+        *,
+        mask: Optional[torch.Tensor] = None,
+        encoder_input: Optional[Dict] = None,
+        encoder_mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> Tensor:
-        if encoder_input is None:
-            return self.model(tokens, encoder_mask=encoder_mask)
         return self.model(
-            tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
+            tokens,
+            mask=mask,
+            encoder_input=encoder_input,
+            encoder_mask=encoder_mask,
+            input_pos=input_pos,
         )

-    def setup_caches(self, max_batch_size, dtype):
-        self.model.setup_caches(max_batch_size, dtype=dtype)
+    def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
+        self.model.setup_caches(
+            batch_size=batch_size,
+            dtype=dtype,
+            encoder_max_seq_len=encoder_max_seq_len,
+            decoder_max_seq_len=decoder_max_seq_len,
+        )

     def reset_caches(self):
         self.model.reset_caches()

From c1a8ff45faee73636e92ac9b8a4297cc9ca8b171 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Tue, 24 Sep 2024 15:10:55 -0700
Subject: [PATCH 2/8] 1/n flamingo e2e ppl

---
 .gitignore            |  3 +++
 torchchat/generate.py | 45 ++++++++++++++++++++++++++++++-------------
 2 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/.gitignore b/.gitignore
index 3f25b76c0..044bad856 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,6 @@ system_info.txt
 # build artifacts
 checkpoints/
 exportedModels/
+
+# test script
+_torchchat_test_script.py
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 14c4832e3..a6744d85d 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -21,7 +21,7 @@
 import torch._inductor.config

 try:
-    from _torchchat_test_script import flamingo_transform, padded_collate
+    from _torchchat_test_script import flamingo_transform
 except ImportError:
     pass

@@ -38,8 +38,9 @@
 from torchchat.utils.device_info import get_device_info

 # torchtune model definition dependencies
-from torchtune.data import Message
-from torchtune.generation._generation import sample as tune_sample
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.training import set_default_dtype

@@ -357,15 +358,25 @@ def prefill(

         if batch is not None:
             # TODO: Verify sequential prefill works with multimodal models
-            logits = model(**batch)[:, -1]
-            return tune_sample(logits, 0, 500)
+            tokens = batch["tokens"]
+            if 'encoder_input' in tokens:
+                encoder_input = tokens['encoder_input']
+            else:
+                encoder_input = None
+
+            mask = batch["causal_mask"][None, :seq_len]
+            input_pos = batch["input_pos"][None, :seq_len]
+            encoder_mask = batch["encoder_mask"]
+
+            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_post, encoder_mask=encoder_mask)[:, -1]
+            return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}")
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
         elif self.model.config.model_type == ModelType.Flamingo:
-            logits = model(x)
+            assert False, "Flamingo requires batch"
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
@@ -387,10 +398,10 @@ def decode_one_token(
         assert input_pos.shape[-1] == 1
         x = x.view(1, -1)
         if model.config.model_type == ModelType.Flamingo:
-            if batch is not None:
-                logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
-            else:
-                logits = model(x)
+            assert batch is not None, "Flamingo requires batch"
+            mask = batch["causal_mask"][None, input_pos.item(), None, :]
+            encoder_mask = batch["encoder_mask"][:, -1:]
+            logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
         else:
             logits = model(x, input_pos)
         # print(f"x: {x},\n input_pos: {input_pos}\n")
@@ -593,7 +604,7 @@ def generate(
             self.is_torchtune_model
             or self.model.config.model_type == ModelType.Flamingo
         ):
-            model.setup_caches(max_batch_size=1, dtype=self.dtype)
+            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
         else:
             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
@@ -743,8 +754,16 @@ def chat(
             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
             data = transform({"messages": messages}, inference=True)
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
+            batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+            seq_len = len(data["tokens"])
+            total_response_length = seq_len + generator_args.max_new_tokens
+            batch["causal_mask"] = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )
+            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
             encoded = batch["tokens"]

         else:

From 148d4ff6c58cc1d987444aaacdab3196a4d3450a Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Tue, 24 Sep 2024 16:53:42 -0700
Subject: [PATCH 3/8] flamingo e2e enable

---
 torchchat/generate.py | 37 +++++++++++++++++++------------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index a6744d85d..ec44e5344 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -359,16 +359,16 @@ def prefill(
         if batch is not None:
             # TODO: Verify sequential prefill works with multimodal models
             tokens = batch["tokens"]
-            if 'encoder_input' in tokens:
-                encoder_input = tokens['encoder_input']
+            if 'encoder_input' in batch:
+                encoder_input = batch['encoder_input']
             else:
                 encoder_input = None
-
+
+            seq_len = tokens.size(1)
             mask = batch["causal_mask"][None, :seq_len]
-            input_pos = batch["input_pos"][None, :seq_len]
             encoder_mask = batch["encoder_mask"]
-
-            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_post, encoder_mask=encoder_mask)[:, -1]
+            input_pos = input_pos.view(1, -1)
+            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
             return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
@@ -604,7 +604,7 @@ def generate(
             self.is_torchtune_model
             or self.model.config.model_type == ModelType.Flamingo
         ):
-            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
+            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=max_seq_length-1)
         else:
             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
@@ -753,18 +753,19 @@ def chat(
             ]
             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
-            data = transform({"messages": messages}, inference=True)
-            batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
-            seq_len = len(data["tokens"])
-            total_response_length = seq_len + generator_args.max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
+
+            with torch.device(device=self.builder_args.device):
+                data = transform({"messages": messages}, inference=True)
+                batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                seq_len = len(data["tokens"])
+                batch["causal_mask"] = torch.tril(
+                    torch.ones(
+                        size=(generator_args.max_new_tokens, generator_args.max_new_tokens),
+                        dtype=torch.bool,
+                    )
                 )
-            )
-            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-            encoded = batch["tokens"]
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                encoded = batch["tokens"]

         else:
             encoded = self.encode_tokens(

From f15957e2c942b00618dec2b2b8e37634c08ce9b5 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Tue, 24 Sep 2024 17:23:02 -0700
Subject: [PATCH 4/8] bump up tune version

---
 install/install_requirements.sh | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/install/install_requirements.sh b/install/install_requirements.sh
index 47fd5b36d..d7183ee30 100755
--- a/install/install_requirements.sh
+++ b/install/install_requirements.sh
@@ -52,10 +52,6 @@ PYTORCH_NIGHTLY_VERSION=dev20240901
 # Nightly version for torchvision
 VISION_NIGHTLY_VERSION=dev20240901

-# Nightly version for torchtune
-TUNE_NIGHTLY_VERSION=dev20240916
-
-
 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
 (
   set -x
@@ -76,7 +72,6 @@ fi
 REQUIREMENTS_TO_INSTALL=(
   torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
   torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
-  torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
 )

 # Install the requirements. --extra-index-url tells pip to look for package
@@ -92,6 +87,12 @@ REQUIREMENTS_TO_INSTALL=(
   $PIP_EXECUTABLE install torchao=="0.5.0"
 )

+# Rely on the latest torchtune for flamingo support
+(
+  set -x
+  $PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git@18efc81dda1c537bb7c25058ff059b4623ccff58
+)
+
 if [[ -x "$(command -v nvidia-smi)" ]]; then
   (
     set -x

From 0ac5f5005bcc3ff0b15dd2fa64b8813be6ecf3aa Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Tue, 24 Sep 2024 21:17:34 -0700
Subject: [PATCH 5/8] remove hacky cache size, add comment for magic number

---
 torchchat/generate.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index ec44e5344..593029c11 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -604,7 +604,8 @@ def generate(
             self.is_torchtune_model
             or self.model.config.model_type == ModelType.Flamingo
         ):
-            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=max_seq_length-1)
+            # 6404 is one-gpu affordable max_seq_length for single image input
+            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
         else:
             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
@@ -758,14 +759,15 @@ def chat(
                 data = transform({"messages": messages}, inference=True)
                 batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
                 seq_len = len(data["tokens"])
+                total_response_length = seq_len + generator_args.max_new_tokens
                 batch["causal_mask"] = torch.tril(
                     torch.ones(
-                        size=(generator_args.max_new_tokens, generator_args.max_new_tokens),
+                        size=(total_response_length, total_response_length),
                         dtype=torch.bool,
                     )
                 )
                 batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                encoded = batch["tokens"]
+                encoded = batch["tokens"].view(-1)

         else:
             encoded = self.encode_tokens(

From 21ffafe1447d6346a4461ef6c21bca8ae7cf1faf Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Tue, 24 Sep 2024 22:17:17 -0700
Subject: [PATCH 6/8] dtype set for input

---
 torchchat/generate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index 593029c11..633161c30 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -755,7 +755,7 @@ def chat(
             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))

-            with torch.device(device=self.builder_args.device):
+            with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
                 data = transform({"messages": messages}, inference=True)
                 batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
                 seq_len = len(data["tokens"])

From 437fd3eab810fbcb76b589d9ae974d81a8b4835b Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Tue, 24 Sep 2024 22:41:40 -0700
Subject: [PATCH 7/8] manually cast dtype

---
 torchchat/generate.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index 633161c30..6b7dc1432 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -758,6 +758,8 @@ def chat(
             with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
                 data = transform({"messages": messages}, inference=True)
                 batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                # set_default_dtype cannot handle the dtype of the image tensor inside the batch; need to manually cast it
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
                 seq_len = len(data["tokens"])
                 total_response_length = seq_len + generator_args.max_new_tokens
                 batch["causal_mask"] = torch.tril(

From 5ce5e9d0ec8422b39dc6e06f2ae104bc81717da8 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Wed, 25 Sep 2024 00:01:03 -0700
Subject: [PATCH 8/8] extra config for deep fusion module

---
 torchchat/model.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torchchat/model.py b/torchchat/model.py
index 98f37313f..edb0ce3d5 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -458,6 +458,13 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(TransformerArgs.from_params(config_args))
             else:
                 modules[name] = module_class(**config_args)
+
+        # Temporarily add extra params to the DeepFusionModel.
+        # TODO: Remove it once we can make fusion model configurable in model_param.
+        if recipe.fusion_class == DeepFusionModel:
+            modules["encoder_trainable"] = False
+            modules["decoder_trainable"] = False
+            modules["fusion_trainable"] = False

         return recipe.fusion_class(**modules)
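
Usage sketch (illustrative only, not part of the patch series): the snippet below shows how the keyword-only FlamingoModel.forward and the new setup_caches signature from PATCH 1/8 line up with the cache sizes and causal mask built in generate.py above, for a text-only prefill. It assumes `model` is an already-built FlamingoModel (e.g. produced by torchchat's builder); the prompt length, dtype, and placeholder tokens are invented for the example.

import torch
from torchtune.generation import sample as tune_sample

prompt_len = 32
max_new_tokens = 128
total_len = prompt_len + max_new_tokens

# Cache setup mirrors generate(): 6404 is the single-image encoder budget used there.
model.setup_caches(
    batch_size=1,
    dtype=torch.bfloat16,
    encoder_max_seq_len=6404,
    decoder_max_seq_len=total_len,
)

# Full causal mask over the whole response window, as built in chat().
causal_mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))

tokens = torch.zeros(1, prompt_len, dtype=torch.long)   # placeholder prompt tokens
input_pos = torch.arange(prompt_len).view(1, -1)

# Text-only prefill: encoder_input/encoder_mask stay None under the new
# keyword-only forward() signature.
logits = model(
    tokens,
    mask=causal_mask[None, :prompt_len],
    input_pos=input_pos,
)
next_token = tune_sample(logits[:, -1], temperature=0, top_k=500)

The temperature=0/top_k=500 sampling call copies the values hard-coded in PATCH 2/8; a real multimodal run would instead feed tokens, encoder_input, and encoder_mask produced by flamingo_transform and padded_collate_tiled_images_and_mask as shown in chat().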