This repository was archived by the owner on Sep 10, 2025. It is now read-only.
  
  
  - 
                Notifications
    You must be signed in to change notification settings 
- Fork 248
Bump torchtune pin to a 9-24 commit; Update Flamingo Definition #1195
          
     Merged
      
      
    
  
     Merged
                    Changes from 4 commits
      Commits
    
    
            Show all changes
          
          
            8 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      d11f0e4
              
                update flamingo model for tune
              
              
                Gasoonjia c1a8ff4
              
                1/n flamingo e2e ppl
              
              
                Gasoonjia 148d4ff
              
                flamingo e2e enable
              
              
                Gasoonjia f15957e
              
                bump up tune version
              
              
                Gasoonjia 0ac5f50
              
                remove hacky cache size, add comment for magic number
              
              
                Gasoonjia 21ffafe
              
                dytpe set for input
              
              
                Gasoonjia 437fd3e
              
                manually cast dtype
              
              
                Gasoonjia 5ce5e9d
              
                extra config for deep fusion module
              
              
                Gasoonjia File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -30,3 +30,6 @@ system_info.txt | |
| # build artifacts | ||
| checkpoints/ | ||
| exportedModels/ | ||
|  | ||
| # test script | ||
| _torchchat_test_script.py | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -21,7 +21,7 @@ | |
| import torch._inductor.config | ||
|  | ||
| try: | ||
| from _torchchat_test_script import flamingo_transform, padded_collate | ||
| from _torchchat_test_script import flamingo_transform | ||
| except ImportError: | ||
| pass | ||
|  | ||
|  | @@ -38,8 +38,9 @@ | |
| from torchchat.utils.device_info import get_device_info | ||
|  | ||
| # torchtune model definition dependencies | ||
| from torchtune.data import Message | ||
| from torchtune.generation._generation import sample as tune_sample | ||
| from torchtune.data import Message, padded_collate_tiled_images_and_mask | ||
|  | ||
| from torchtune.generation import sample as tune_sample | ||
| from torchtune.models.llama3 import llama3_tokenizer | ||
| from torchtune.training import set_default_dtype | ||
|  | ||
|  | @@ -357,15 +358,25 @@ def prefill( | |
|  | ||
| if batch is not None: | ||
| # TODO: Verify sequential prefill works with multimodal models | ||
| logits = model(**batch)[:, -1] | ||
| return tune_sample(logits, 0, 500) | ||
| tokens = batch["tokens"] | ||
| if 'encoder_input' in batch: | ||
| encoder_input = batch['encoder_input'] | ||
| else: | ||
| encoder_input = None | ||
|  | ||
| seq_len = tokens.size(1) | ||
| mask = batch["causal_mask"][None, :seq_len] | ||
| encoder_mask = batch["encoder_mask"] | ||
| input_pos = input_pos.view(1, -1) | ||
| logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1] | ||
| return tune_sample(logits, temperature=0, top_k=500) | ||
| elif sequential_prefill: | ||
| for i in range(width): | ||
| x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) | ||
| # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}") | ||
| logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i]) | ||
| elif self.model.config.model_type == ModelType.Flamingo: | ||
| logits = model(x) | ||
| assert False, "Flamingo requires batch" | ||
| else: | ||
| # input_pos: [B, S] | ||
| logits = model(x, input_pos) | ||
|  | @@ -387,10 +398,10 @@ def decode_one_token( | |
| assert input_pos.shape[-1] == 1 | ||
| x = x.view(1, -1) | ||
| if model.config.model_type == ModelType.Flamingo: | ||
| if batch is not None: | ||
| logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:]) | ||
| else: | ||
| logits = model(x) | ||
| assert batch is not None, "Flamingo requires batch" | ||
| mask = batch["causal_mask"][None, input_pos.item(), None, :] | ||
| encoder_mask = batch["encoder_mask"][:, -1:] | ||
| logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:] | ||
| else: | ||
| logits = model(x, input_pos) | ||
| # print(f"x: {x},\n input_pos: {input_pos}\n") | ||
|  | @@ -593,7 +604,7 @@ def generate( | |
| self.is_torchtune_model | ||
| or self.model.config.model_type == ModelType.Flamingo | ||
| ): | ||
| model.setup_caches(max_batch_size=1, dtype=self.dtype) | ||
| model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=max_seq_length-1) | ||
|          | ||
| else: | ||
| model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) | ||
| if is_speculative and draft_model is not model: | ||
|  | @@ -742,10 +753,19 @@ def chat( | |
| ] | ||
|  | ||
| transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path)) | ||
| data = transform({"messages": messages}, inference=True) | ||
| batch = padded_collate([data], self.builder_args.device) | ||
| batch.pop("mask") | ||
| encoded = batch["tokens"] | ||
|  | ||
| with torch.device(device=self.builder_args.device): | ||
| data = transform({"messages": messages}, inference=True) | ||
| batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1) | ||
| seq_len = len(data["tokens"]) | ||
| batch["causal_mask"] = torch.tril( | ||
| torch.ones( | ||
| size=(generator_args.max_new_tokens, generator_args.max_new_tokens), | ||
| dtype=torch.bool, | ||
| ) | ||
| ) | ||
| batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] | ||
| encoded = batch["tokens"] | ||
|  | ||
| else: | ||
| encoded = self.encode_tokens( | ||
|  | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works on Mac now right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah and mac tests passed