import os
import re
import sys
from typing import Dict

import numpy as np
import PIL.Image
import torch

from torchtune import training
from torchtune.data import Message
from torchtune.generation._generation import sample
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder, FlamingoTransform
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.training import set_default_dtype

from torchchat.model import Model


def flamingo_transform(tokenizer_path):
    return FlamingoTransform(
        tokenizer_path,
        tile_size=448,
        patch_size=14,
        max_num_tiles=4,
        max_seq_len=8192,
        encoder_max_seq_len=4100,
        image_mean=(0.48145466, 0.4578275, 0.40821073),
        image_std=(0.26862954, 0.26130258, 0.27577711),
        prompt_template=None,
    )

def padded_collate(batch, device='cuda', dtype=torch.bfloat16, padding_idx=0):
    # Placeholder collator until https://github.com/pytorch/torchtune/pull/1156 lands
    assert len(batch) == 1, "Test collate function only supports bs = 1"
    sample = batch[0]
    sample["tokens"] = torch.Tensor(sample["tokens"])[None, ...].to(device).long()
    sample["mask"] = torch.Tensor(sample["mask"])[None, ...].to(device).bool()
    sample["encoder_input"]["images"] = torch.stack(sample["encoder_input"]["images"])[None, ...].to(device)
    sample["encoder_input"]["aspect_ratio"] = torch.stack(sample["encoder_input"]["aspect_ratio"])[None, ...].to(device)
    assert len(sample["encoder_mask"]) == 1, "Test collate function only supports 1 image per sequence"
    # Pad encoder mask to the encoder max sequence length (4100 = 4 tiles x (32*32 patches + 1 CLS token))
    s_x, s_y = sample["encoder_mask"][0].shape
    mask_padding = torch.zeros((s_x, 4100 - s_y), dtype=torch.bool)
    encoder_mask = torch.cat([sample["encoder_mask"][0], mask_padding], dim=1)
    sample["encoder_mask"] = encoder_mask[None, ...].to(device)
    return sample
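
""":md
A sketch of how this collator would be invoked (it assumes the `data` sample built in the
`__main__` block below; the exact sequence lengths depend on the prompt, and the shape comments
are assumptions inferred from the padding above):

```python
batch = padded_collate([data], device="cuda", dtype=torch.bfloat16)
# batch["tokens"]:       (1, seq_len) token ids
# batch["encoder_mask"]: (1, seq_len, 4100) bool, padded out to encoder_max_seq_len
```
"""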


_FROM_META = {
    "text_model.tok_embeddings.weight": "decoder.tok_embeddings.weight",
    "text_model.learnable_embedding.weight": "decoder.tok_embeddings.fusion_embedding.weight",
    "text_model.norm.weight": "decoder.norm.scale",
    "text_model.output.weight": "decoder.output.weight",

    "text_model.layers.{}.attention_norm.weight": "decoder.layers.{}.sa_norm.scale",
    "text_model.layers.{}.attention.wq.weight": "decoder.layers.{}.attn.q_proj.weight",
    "text_model.layers.{}.attention.wk.weight": "decoder.layers.{}.attn.k_proj.weight",
    "text_model.layers.{}.attention.wv.weight": "decoder.layers.{}.attn.v_proj.weight",
    "text_model.layers.{}.attention.wo.weight": "decoder.layers.{}.attn.output_proj.weight",
    "text_model.layers.{}.ffn_norm.weight": "decoder.layers.{}.mlp_norm.scale",
    "text_model.layers.{}.feed_forward.w1.weight": "decoder.layers.{}.mlp.w1.weight",
    "text_model.layers.{}.feed_forward.w3.weight": "decoder.layers.{}.mlp.w3.weight",
    "text_model.layers.{}.feed_forward.w2.weight": "decoder.layers.{}.mlp.w2.weight",

    "text_model.cross_attention_layers.{}.gate_attn": "decoder.layers.{}.fusion_layer.ca_scale.scale",
    "text_model.cross_attention_layers.{}.gate_ffwd": "decoder.layers.{}.fusion_layer.mlp_scale.scale",
    "text_model.cross_attention_layers.{}.attention_norm.weight": "decoder.layers.{}.fusion_layer.ca_norm.scale",
    "text_model.cross_attention_layers.{}.ffn_norm.weight": "decoder.layers.{}.fusion_layer.mlp_norm.scale",
    "text_model.cross_attention_layers.{}.attention.wq.weight": "decoder.layers.{}.fusion_layer.attn.q_proj.weight",
    "text_model.cross_attention_layers.{}.attention.wk.weight": "decoder.layers.{}.fusion_layer.attn.k_proj.weight",
    "text_model.cross_attention_layers.{}.attention.wv.weight": "decoder.layers.{}.fusion_layer.attn.v_proj.weight",
    "text_model.cross_attention_layers.{}.attention.wo.weight": "decoder.layers.{}.fusion_layer.attn.output_proj.weight",
    "text_model.cross_attention_layers.{}.attention.inner_attention.q_norm.weight": "decoder.layers.{}.fusion_layer.attn.q_norm.scale",
    "text_model.cross_attention_layers.{}.attention.inner_attention.k_norm.weight": "decoder.layers.{}.fusion_layer.attn.k_norm.scale",
    "text_model.cross_attention_layers.{}.feed_forward.w1.weight": "decoder.layers.{}.fusion_layer.mlp.w1.weight",
    "text_model.cross_attention_layers.{}.feed_forward.w3.weight": "decoder.layers.{}.fusion_layer.mlp.w3.weight",
    "text_model.cross_attention_layers.{}.feed_forward.w2.weight": "decoder.layers.{}.fusion_layer.mlp.w2.weight",

    "vision_model.vision_encoder.positional_embedding": "encoder.clip.token_pos_embedding.local_token_positional_embedding",
    "vision_model.vision_encoder.gated_positional_embedding": "encoder.clip.token_pos_embedding.global_token_positional_embedding",
    "vision_model.vision_encoder.gated_positional_embedding_gate": "encoder.clip.token_pos_embedding.gate",
    "vision_model.vision_encoder.ln_pre.weight": "encoder.clip.ln_pre.weight",
    "vision_model.vision_encoder.ln_pre.bias": "encoder.clip.ln_pre.bias",
    "vision_model.vision_encoder.ln_post.weight": "encoder.clip.ln_post.weight",
    "vision_model.vision_encoder.ln_post.bias": "encoder.clip.ln_post.bias",
    "vision_model.vision_encoder.pre_tile_pos_embed.embedding": "encoder.clip.pre_tile_pos_embed.embedding",
    "vision_model.vision_encoder.pre_tile_pos_embed.gate": "encoder.clip.pre_tile_pos_embed.gate",
    "vision_model.vision_encoder.post_tile_pos_embed.embedding": "encoder.clip.post_tile_pos_embed.embedding",
    "vision_model.vision_encoder.post_tile_pos_embed.gate": "encoder.clip.post_tile_pos_embed.gate",
    "vision_model.vision_encoder.class_embedding": "encoder.clip.cls_token_embedding.weight",
    "vision_model.vision_encoder.conv1._linear.weight": "encoder.clip.conv.weight",

    "vision_model.vision_encoder.transformer.resblocks.{}.attn.wq.weight": "encoder.clip.layers.{}.attn.q_proj.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.attn.wk.weight": "encoder.clip.layers.{}.attn.k_proj.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.attn.wv.weight": "encoder.clip.layers.{}.attn.v_proj.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.attn.wo.weight": "encoder.clip.layers.{}.attn.output_proj.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_fc.weight": "encoder.clip.layers.{}.mlp.w1.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_fc.bias": "encoder.clip.layers.{}.mlp.w1.bias",
    "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_proj.weight": "encoder.clip.layers.{}.mlp.w2.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_proj.bias": "encoder.clip.layers.{}.mlp.w2.bias",
    "vision_model.vision_encoder.transformer.resblocks.{}.ln_1.weight": "encoder.clip.layers.{}.sa_norm.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.ln_1.bias": "encoder.clip.layers.{}.sa_norm.bias",
    "vision_model.vision_encoder.transformer.resblocks.{}.ln_2.weight": "encoder.clip.layers.{}.mlp_norm.weight",
    "vision_model.vision_encoder.transformer.resblocks.{}.ln_2.bias": "encoder.clip.layers.{}.mlp_norm.bias",

    "vision_model.vision_projection.weight": "encoder.projection.output.weight",
    "vision_model.vision_projection.bias": "encoder.projection.output.bias",

    "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wq.weight": "encoder.projection.layers.{}.attn.q_proj.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wk.weight": "encoder.projection.layers.{}.attn.k_proj.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wv.weight": "encoder.projection.layers.{}.attn.v_proj.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wo.weight": "encoder.projection.layers.{}.attn.output_proj.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_fc.weight": "encoder.projection.layers.{}.mlp.w1.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_fc.bias": "encoder.projection.layers.{}.mlp.w1.bias",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_proj.weight": "encoder.projection.layers.{}.mlp.w2.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_proj.bias": "encoder.projection.layers.{}.mlp.w2.bias",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_1.weight": "encoder.projection.layers.{}.sa_norm.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_1.bias": "encoder.projection.layers.{}.sa_norm.bias",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_2.weight": "encoder.projection.layers.{}.mlp_norm.weight",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_2.bias": "encoder.projection.layers.{}.mlp_norm.bias",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.gate_attn": "encoder.projection.layers.{}.sa_scale.scale",
    "vision_model.vision_encoder.global_transformer.resblocks.{}.gate_ffn": "encoder.projection.layers.{}.mlp_scale.scale",
}


def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    try:
        if any(k.isdigit() for k in key.split(".")):
            # Replace the layer number with "{}" to create the key for lookup
            abstract_key = re.sub(r"(\.\d+)", ".{}", key)
            layer_num = re.search(r"\d+", key).group(0)
            new_key = mapping_dict[abstract_key]
            new_key = new_key.format(layer_num)
        else:
            new_key = mapping_dict[key]
    except KeyError as e:
        raise Exception(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format."
        ) from e

    return new_key
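
""":md
A minimal sketch of how the mapping is applied: a numbered Meta key is abstracted to its
`{}` template for the `_FROM_META` lookup, then the layer index is substituted back in
(the layer index below is just an example).

```python
meta_key = "text_model.layers.7.attention.wq.weight"
tune_key = get_mapped_key(meta_key, _FROM_META)
# tune_key == "decoder.layers.7.attn.q_proj.weight"
```
"""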


def flamingo_meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Converter from the Meta state dict to the torchtune state dict. This handles:
    - Updating the cross attention layer numbers
    """
    converted_state_dict = {}

    for key, value in state_dict.items():
        if key == "text_model.rope.freqs":
            continue
        new_key = get_mapped_key(key, _FROM_META)
        if "cross_attention_layers" in key:
            layer = int(key.split(".")[2])
            # TODO: grab num_layers and generalize this
            new_layer = (layer + 1) * 4 - 1
            key_lst = new_key.split(".")
            key_lst[2] = str(new_layer)
            new_key = ".".join(key_lst)
        if "gate_ffwd" in key or "gate_attn" in key:
            value = value[:1]
        elif "conv1" in key:
            # TODO: get patch size and generalize
            value = value.reshape(-1, 3, 14, 14)
        converted_state_dict[new_key] = value
    return converted_state_dict
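
""":md
The cross-attention renumbering above relies on the fusion layers being interleaved every
fourth decoder layer (see the TODO about generalizing over the layer count): Meta's standalone
cross-attention indices land on fused decoder layers 3, 7, 11, and so on. A quick sketch of the
formula:

```python
for meta_idx in range(3):
    print(meta_idx, "->", (meta_idx + 1) * 4 - 1)   # 0 -> 3, 1 -> 7, 2 -> 11
```
"""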


if __name__ == "__main__":
    llama3_2_dir = str(sys.argv[1])
    param_path = os.path.join(llama3_2_dir, "flamingo.json")
    tokenizer_path = os.path.join(llama3_2_dir, "tokenizer.model")
    checkpoint_path = os.path.join(llama3_2_dir, "consolidated.pth")
    image_path = os.path.join(llama3_2_dir, "dog.jpg")

    if len(sys.argv) > 2:
        device = torch.device(str(sys.argv[2]))
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")
    print(f"Loading model from {param_path}")

    dtype = torch.bfloat16
    with set_default_dtype(dtype), device:
        model = Model.from_params(param_path)

    transform = flamingo_transform(tokenizer_path)

    print(f"Loading checkpoint from {checkpoint_path}")
    state_dict = torch.load(checkpoint_path)
    print("Converting state dict into flamingo format")
    state_dict = flamingo_meta_to_tune(state_dict)
    print("Loading state dict into model")
    model.model.load_state_dict(state_dict)

    model = torch.compile(model)
    images = [PIL.Image.open(image_path)]

    dialog = [
        Message(
            role="user",
            content=[
                {"type": "image"},
                {"type": "text", "content": "What's in this image?"},
            ],
            eot=True,
        ),
        Message(role="assistant", content=""),
    ]

    data = transform({"images": images, "messages": dialog}, inference=True)

    model.eval()
    with device:
        model.setup_caches(1, dtype=torch.bfloat16)

    max_generated_tokens = 100
    temperature = 0.6
    top_k = 500

    print("Generating...")

    generated_tokens = []
    model.reset_caches()
    with torch.no_grad():
        batch = padded_collate([data], device, dtype)
        batch.pop("mask")

        # Prefill: run the full prompt (text + image) through the model once
        logits = model(**batch)[:, -1]
        tok = sample(logits, temperature, top_k)
        generated_tokens.append(tok.item())

        # Decode one token at a time, reusing the last row of the encoder mask for the cache
        cache_mask = batch["encoder_mask"][:, -1:]
        for _ in range(max_generated_tokens):
            if tok.item() in transform.stop_tokens:
                break
            logits = model(tok, encoder_mask=cache_mask)[:, -1]
            tok = sample(logits, temperature, top_k)
            generated_tokens.append(tok.item())

    print(transform.decode(generated_tokens))


""":md
## Chat Pseudo Code

This approach guarantees that there is only one image cached at a time, so there is no need for cross attention masking.
This works because Llama3v is trained such that each token is only allowed to attend to the previous image and the rest are
masked during training/finetuning. Since consecutive images are treated as one image for Llama3v, you can control the maximum
encoder sequence length by setting max_consecutive here, as well as by setting max_num_tiles and max_resolution for the image input.

```python
model.eval()
model.setup_caches(1, torch.bfloat16)

with torch.no_grad():
    # Prefill system prompt
    toks, _ = transform(parse_prompt(system_prompt))
    model(toks)
    while True:
        # Prefill user prompt split over images
        user_prompt = input(">>> ")
        toks, imgs = transform(parse_prompt(user_prompt))
        for i, tok in enumerate(split(toks, image_token, max_consecutive=1)):
            img = None
            if imgs is not None:
                img = imgs[i]
            reset_attn_cache(model)
            logits = model(tok, img)

        # Decode assistant response
        tok = sample_tok(logits)  # only outputs single-token logits when model.cache_enabled=True
        while tok != EOS:
            logits = model(tok)
            tok = sample_tok(logits)
            sys.stdout.buffer.write(transform.decode(tok))
```
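
The `split` helper above is pseudocode. A minimal, hypothetical sketch of one way to chunk the
token stream so that each prefill call carries at most one image (it ignores the `max_consecutive`
grouping of back-to-back images described above):

```python
def split_on_image_tokens(tokens, image_token_id):
    # Start a new chunk at every image token so each chunk is prefilled with its own image.
    chunks, current = [], []
    for tok in tokens:
        if tok == image_token_id and current:
            chunks.append(current)
            current = []
        current.append(tok)
    if current:
        chunks.append(current)
    return chunks
```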
"""

""":py"""