diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 9d81f6b88..d11ef99d8 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -7,3 +7,4 @@ import torchtitan.experiments.llama4 # noqa: F401 import torchtitan.experiments.qwen3 import torchtitan.experiments.simple_fsdp # noqa: F401 +import torchtitan.experiments.vlm # noqa: F401 diff --git a/torchtitan/experiments/vlm/README.md b/torchtitan/experiments/vlm/README.md new file mode 100644 index 000000000..f160d0205 --- /dev/null +++ b/torchtitan/experiments/vlm/README.md @@ -0,0 +1,58 @@ +# Vision Language Model training in `torchtitan` + +**under active development** + +This folder showcases how to train modern Vision Language Model (vlm) in torchtitan. + + +## Features: +- Native Aspect Ratio: not limited to square crops. +- Native Resolution: images in a batch can have different sizes, no more image tiles and thumbnails. +- Native Interleaved data: training samples can have variable number of images, interleaved with text at different position. You can train more than just a captioning model. + + +## Design +Distributed training usually does not play nice with input of varying shapes. To handle a varying number of images and image sizes, we requires two hyperparameters, image batch size `N` and image length `L` (in patches), and pad the actual image patches to this fixed size. +Then we scatter the patch embeddings to their actual positions in the LLM input tokens. + +Screenshot 2025-08-21 at 16 21 57 + +- After `tok_embedding`, we obtain tokens of shape `BxS`. +- After `encoder`, we obtain visual tokens of shape `NxL`. +- We extract the valid visual tokens only +- Then scatter those tokens to their actual positions in the LLM input tokens. + + +This result in a very simple and general interface to train modern VLM with interleaved data and native resolution & aspect ratio: +- Depending on data mixtures, we can set dataloader's hyperparameters `N, L` to have minimal empty image padding (in batch dimension). +- Use modern pytorch features (FlexAttention, compile etc) for efficient handling of different attention mask per (padding in sequence dimension). +- Interface nicely with TP, PP, etc + + +## Implementation + +### Dataloader +This approach requires the dataloader to handle the following aspect: +- [x] Interleave the correct precise numbers of image tokens in the inputs token based on encoder's patch size and input images' size +- [x] Convert images/videos to 1D sequence of patchs: + - `rearrange(pixels, 'n (t pt) (h ph) (w pw) c -> n (t h w) (pt p pw c)', pt=temporal_ps, ph=patch_size, pw=patch_size)` + - Pad all image patches sequence to a fixed length and return `pixel_values.shape == [N, L, D]` +- [x] Return a `grid_thw.shape == [N, L, 3]` to keep track of the location indicies of each patches in the images. Padding image can be tracked in the same tensors with values `-1`. +- [x] LLM Sample / Document Packing. +- [x] Captioning dataset: CC12M +- [x] Interleaved dataset: Obelics + + + +### Model +We also need Ar pretrained vision encoder with support for native resolution and aspect ratio. There is relatively few Vision Encoder that have this capability up until recently, including Siglip2, AimV2, and most recently DINOv3. +- [ ] Currently we support Siglip2 encoder using Positional Embedding interpolation approach. + - [x] Base modelling code. + - [ ] Weights conversion and loading from HF. +- [x] FSDP for both Encoder and Decoder +- [x] Context Parallel for LLM only, since we will use FlexAttention for Encoder. +- [ ] FlexAttention for with different seq len per image. +- [ ] Compile for Encoder + Deocoder +- [ ] Tensor Parallel +- [ ] Pipeline Parallel + diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py new file mode 100644 index 000000000..244a5231c --- /dev/null +++ b/torchtitan/experiments/vlm/__init__.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .datasets.mm_datasets import build_mm_dataloader +from .infra.parallelize import parallelize_vlm +# from .infra.pipeline import pipeline_llama +from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs +from .model.model import Llama3Siglip2Transformer + +__all__ = [ + "parallelize_vlm", + # "pipeline_llama", + "Llama3Siglip2ModelArgs", + "Llama3Siglip2Transformer", + "llama3_siglip2_configs", +] + + +siglip2_configs = { + "debugmodel": Siglip2ModelArgs( + dim=128, + ffn_dim=256, + n_layers=4, + n_heads=2, + ) +} + +llama3_siglip2_configs = { + "debugmodel": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000 + ), + "debugmodel_flex_attn": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=256, + n_layers=6, + n_heads=16, + vocab_size=2000, + rope_theta=500000, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "8B": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + ), + "70B": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=8192, + n_layers=80, + n_heads=64, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=4096, + rope_theta=500000, + ), + "405B": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=16384, + n_layers=126, + n_heads=128, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=4096, + rope_theta=500000, + ), +} + + +register_train_spec( + TrainSpec( + name="llama3-siglip2", + model_cls=Llama3Siglip2Transformer, + model_args=llama3_siglip2_configs, + parallelize_fn=parallelize_vlm, + pipelining_fn=None, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_mm_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + # state_dict_adapter=Llama3StateDictAdapter, + ) +) diff --git a/torchtitan/experiments/vlm/assets/job_config.py b/torchtitan/experiments/vlm/assets/job_config.py new file mode 100644 index 000000000..8f206c466 --- /dev/null +++ b/torchtitan/experiments/vlm/assets/job_config.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class Training: + max_images_per_batch: int = 10 + """Vision encoder batch size (N)""" + max_patches_per_image: int = 256 + """Vision encoder sequence length (L)""" + patch_size: int = 16 + """ Patch size of the vision encoder. + For example, image size 256x256, patch size 16 + Number of visual tokens is: (256/16)**2=256 + """ + spatial_merge_size: int = 1 + """ Spatially merge visual tokens after encoder. + For example: image size 256x256, patch size 16, spaitl merge size is 2 + Number of visual tokens for the LLM: (256/16/2)**2 = 8 + """ + packing_buffer_size: int = 0 + """ Set to a value >0 to enable sample packing. + This control the buffer uses to store training samples avaliable for packing. + """ + + +# HACK: couldn't figure out how to modify the HF tokenizer's json +# to make these attribute accesible. Ideally these should be accesible from the tokenizer itself. +@dataclass +class SpecialTokens: + img_token: str = "<|image|>" + boi_token: str = "<|begin_of_image|>" + eoi_token: str = "<|end_of_image|>" + img_id: int = 1998 + boi_id: int = 1999 + eoi_id: int = 2000 + pad_id: int = 2001 + + +@dataclass +class JobConfig: + training: Training = field(default_factory=Training) + special_tokens: SpecialTokens = field(default_factory=SpecialTokens) diff --git a/torchtitan/experiments/vlm/assets/tokenizer/tokenizer.json b/torchtitan/experiments/vlm/assets/tokenizer/tokenizer.json new file mode 100644 index 000000000..7eed211ed --- /dev/null +++ b/torchtitan/experiments/vlm/assets/tokenizer/tokenizer.json @@ -0,0 +1,2041 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": true + }, + "post_processor": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": false, + "vocab": { + "\u0000": 0, + "\u0001": 1, + "\u0002": 2, + "\u0003": 3, + "\u0004": 4, + "\u0005": 5, + "\u0006": 6, + "\u0007": 7, + "\b": 8, + "\t": 9, + "\n": 10, + "\u000b": 11, + "\f": 12, + "\r": 13, + "\u000e": 14, + "\u000f": 15, + "\u0010": 16, + "\u0011": 17, + "\u0012": 18, + "\u0013": 19, + "\u0014": 20, + "\u0015": 21, + "\u0016": 22, + "\u0017": 23, + "\u0018": 24, + "\u0019": 25, + "\u001a": 26, + "\u001b": 27, + "\u001c": 28, + "\u001d": 29, + "\u001e": 30, + "\u001f": 31, + " ": 32, + "!": 33, + "\"": 34, + "#": 35, + "$": 36, + "%": 37, + "&": 38, + "'": 39, + "(": 40, + ")": 41, + "*": 42, + "+": 43, + ",": 44, + "-": 45, + ".": 46, + "/": 47, + "0": 48, + "1": 49, + "2": 50, + "3": 51, + "4": 52, + "5": 53, + "6": 54, + "7": 55, + "8": 56, + "9": 57, + ":": 58, + ";": 59, + "<": 60, + "=": 61, + ">": 62, + "?": 63, + "@": 64, + "A": 65, + "B": 66, + "C": 67, + "D": 68, + "E": 69, + "F": 70, + "G": 71, + "H": 72, + "I": 73, + "J": 74, + "K": 75, + "L": 76, + "M": 77, + "N": 78, + "O": 79, + "P": 80, + "Q": 81, + "R": 82, + "S": 83, + "T": 84, + "U": 85, + "V": 86, + "W": 87, + "X": 88, + "Y": 89, + "Z": 90, + "[": 91, + "\\": 92, + "]": 93, + "^": 94, + "_": 95, + "`": 96, + "a": 97, + "b": 98, + "c": 99, + "d": 100, + "e": 101, + "f": 102, + "g": 103, + "h": 104, + "i": 105, + "j": 106, + "k": 107, + "l": 108, + "m": 109, + "n": 110, + "o": 111, + "p": 112, + "q": 113, + "r": 114, + "s": 115, + "t": 116, + "u": 117, + "v": 118, + "w": 119, + "x": 120, + "y": 121, + "z": 122, + "{": 123, + "|": 124, + "}": 125, + "~": 126, + "": 127, + "\\x80": 128, + "\\x81": 129, + "\\x82": 130, + "\\x83": 131, + "\\x84": 132, + "\\x85": 133, + "\\x86": 134, + "\\x87": 135, + "\\x88": 136, + "\\x89": 137, + "\\x8a": 138, + "\\x8b": 139, + "\\x8c": 140, + "\\x8d": 141, + "\\x8e": 142, + "\\x8f": 143, + "\\x90": 144, + "\\x91": 145, + "\\x92": 146, + "\\x93": 147, + "\\x94": 148, + "\\x95": 149, + "\\x96": 150, + "\\x97": 151, + "\\x98": 152, + "\\x99": 153, + "\\x9a": 154, + "\\x9b": 155, + "\\x9c": 156, + "\\x9d": 157, + "\\x9e": 158, + "\\x9f": 159, + "\\xa0": 160, + "\\xa1": 161, + "\\xa2": 162, + "\\xa3": 163, + "\\xa4": 164, + "\\xa5": 165, + "\\xa6": 166, + "\\xa7": 167, + "\\xa8": 168, + "\\xa9": 169, + "\\xaa": 170, + "\\xab": 171, + "\\xac": 172, + "\\xad": 173, + "\\xae": 174, + "\\xaf": 175, + "\\xb0": 176, + "\\xb1": 177, + "\\xb2": 178, + "\\xb3": 179, + "\\xb4": 180, + "\\xb5": 181, + "\\xb6": 182, + "\\xb7": 183, + "\\xb8": 184, + "\\xb9": 185, + "\\xba": 186, + "\\xbb": 187, + "\\xbc": 188, + "\\xbd": 189, + "\\xbe": 190, + "\\xbf": 191, + "\\xc0": 192, + "\\xc1": 193, + "\\xc2": 194, + "\\xc3": 195, + "\\xc4": 196, + "\\xc5": 197, + "\\xc6": 198, + "\\xc7": 199, + "\\xc8": 200, + "\\xc9": 201, + "\\xca": 202, + "\\xcb": 203, + "\\xcc": 204, + "\\xcd": 205, + "\\xce": 206, + "\\xcf": 207, + "\\xd0": 208, + "\\xd1": 209, + "\\xd2": 210, + "\\xd3": 211, + "\\xd4": 212, + "\\xd5": 213, + "\\xd6": 214, + "\\xd7": 215, + "\\xd8": 216, + "\\xd9": 217, + "\\xda": 218, + "\\xdb": 219, + "\\xdc": 220, + "\\xdd": 221, + "\\xde": 222, + "\\xdf": 223, + "\\xe0": 224, + "\\xe1": 225, + "\\xe2": 226, + "\\xe3": 227, + "\\xe4": 228, + "\\xe5": 229, + "\\xe6": 230, + "\\xe7": 231, + "\\xe8": 232, + "\\xe9": 233, + "\\xea": 234, + "\\xeb": 235, + "\\xec": 236, + "\\xed": 237, + "\\xee": 238, + "\\xef": 239, + "\\xf0": 240, + "\\xf1": 241, + "\\xf2": 242, + "\\xf3": 243, + "\\xf4": 244, + "\\xf5": 245, + "\\xf6": 246, + "\\xf7": 247, + "\\xf8": 248, + "\\xf9": 249, + "\\xfa": 250, + "\\xfb": 251, + "\\xfc": 252, + "\\xfd": 253, + "\\xfe": 254, + "\\xff": 255, + " t": 256, + "he": 257, + " a": 258, + "in": 259, + " s": 260, + " w": 261, + " the": 262, + " o": 263, + "re": 264, + " b": 265, + "ou": 266, + "ed": 267, + " m": 268, + "nd": 269, + " I": 270, + "ha": 271, + "it": 272, + "er": 273, + "ing": 274, + " f": 275, + "is": 276, + " to": 277, + "en": 278, + "on": 279, + "or": 280, + "as": 281, + " c": 282, + " of": 283, + " and": 284, + " d": 285, + "ll": 286, + "at": 287, + "an": 288, + "ar": 289, + " p": 290, + " n": 291, + " in": 292, + "le": 293, + "om": 294, + "ot": 295, + " be": 296, + " h": 297, + "ut": 298, + "ow": 299, + "es": 300, + "hat": 301, + " g": 302, + " he": 303, + " ha": 304, + " l": 305, + " was": 306, + "ld": 307, + "gh": 308, + "id": 309, + "ch": 310, + " th": 311, + " it": 312, + "ay": 313, + " on": 314, + "ce": 315, + "se": 316, + "ent": 317, + " st": 318, + "ly": 319, + "ve": 320, + "et": 321, + "st": 322, + " T": 323, + " e": 324, + " y": 325, + "ght": 326, + "ir": 327, + " me": 328, + "oo": 329, + "al": 330, + "ith": 331, + " re": 332, + "im": 333, + " that": 334, + " as": 335, + "ould": 336, + "ro": 337, + "ad": 338, + "ion": 339, + ".\n": 340, + "her": 341, + " my": 342, + "ct": 343, + " not": 344, + " with": 345, + " for": 346, + " u": 347, + "ke": 348, + " you": 349, + " S": 350, + " is": 351, + "ight": 352, + "\"\n": 353, + "am": 354, + "ic": 355, + "ur": 356, + " at": 357, + "..": 358, + "ac": 359, + "ter": 360, + " wh": 361, + " an": 362, + " we": 363, + " The": 364, + "if": 365, + " or": 366, + " but": 367, + "ver": 368, + " \"": 369, + " r": 370, + "out": 371, + "ome": 372, + " had": 373, + "pp": 374, + "qu": 375, + " su": 376, + " this": 377, + "red": 378, + "ard": 379, + " so": 380, + "ell": 381, + " would": 382, + " his": 383, + " sh": 384, + "ine": 385, + "ra": 386, + " se": 387, + " by": 388, + ".\"\n": 389, + " P": 390, + "hen": 391, + " A": 392, + " have": 393, + " fr": 394, + " sa": 395, + " H": 396, + " one": 397, + "em": 398, + "ked": 399, + "irt": 400, + "ect": 401, + " him": 402, + " li": 403, + " ab": 404, + "ation": 405, + "hing": 406, + "the": 407, + " R": 408, + " le": 409, + "ss": 410, + " W": 411, + "cu": 412, + "ill": 413, + "'t": 414, + "art": 415, + "all": 416, + ",\n": 417, + "own": 418, + "ore": 419, + " all": 420, + " k": 421, + " go": 422, + "hirt": 423, + "and": 424, + " out": 425, + "ame": 426, + "ain": 427, + " if": 428, + " no": 429, + " do": 430, + " they": 431, + "ool": 432, + "un": 433, + "to": 434, + " up": 435, + " Red": 436, + " ne": 437, + " K": 438, + " from": 439, + " Shirt": 440, + " wor": 441, + "ong": 442, + " there": 443, + " said": 444, + "ri": 445, + "ant": 446, + " B": 447, + " any": 448, + "ud": 449, + "ind": 450, + " whi": 451, + "ab": 452, + "ound": 453, + " about": 454, + " them": 455, + "cup": 456, + "ak": 457, + " de": 458, + " te": 459, + " M": 460, + "ake": 461, + "cupine": 462, + "ig": 463, + " were": 464, + "orcupine": 465, + "il": 466, + "chool": 467, + " ro": 468, + "ood": 469, + " are": 470, + "ive": 471, + " like": 472, + "yo": 473, + " hou": 474, + "'s": 475, + "one": 476, + "us": 477, + "el": 478, + "ul": 479, + "ack": 480, + "op": 481, + ",\"": 482, + "th": 483, + "acher": 484, + "um": 485, + "ang": 486, + " fa": 487, + "ag": 488, + " school": 489, + " j": 490, + "te": 491, + "ok": 492, + "ess": 493, + "ust": 494, + "ers": 495, + "....": 496, + " C": 497, + "ther": 498, + "han": 499, + " when": 500, + " sp": 501, + " man": 502, + " can": 503, + "ough": 504, + " who": 505, + " get": 506, + " did": 507, + " po": 508, + "ci": 509, + " al": 510, + "ist": 511, + " com": 512, + "lf": 513, + "au": 514, + " Porcupine": 515, + " which": 516, + "ven": 517, + " af": 518, + "wn": 519, + "ass": 520, + "ber": 521, + " ex": 522, + "ous": 523, + "est": 524, + "lo": 525, + " tr": 526, + "ellow": 527, + " say": 528, + "ought": 529, + " room": 530, + " some": 531, + "--": 532, + " O": 533, + "ate": 534, + " v": 535, + "hed": 536, + "ap": 537, + " tw": 538, + " bec": 539, + "ree": 540, + "ject": 541, + "ks": 542, + " con": 543, + " been": 544, + "ents": 545, + "ide": 546, + " could": 547, + " G": 548, + "ep": 549, + " pro": 550, + "nt": 551, + " house": 552, + " ag": 553, + " If": 554, + " kn": 555, + " fellow": 556, + " what": 557, + "way": 558, + "ish": 559, + " am": 560, + "ite": 561, + "nder": 562, + "ime": 563, + " pr": 564, + " teacher": 565, + "are": 566, + " bo": 567, + " she": 568, + " N": 569, + "ice": 570, + "ast": 571, + "ure": 572, + "ie": 573, + " such": 574, + "uten": 575, + "utenber": 576, + "utenberg": 577, + " qu": 578, + "lown": 579, + " wr": 580, + "pt": 581, + " He": 582, + " stud": 583, + "here": 584, + " more": 585, + "ry": 586, + "tter": 587, + " Y": 588, + " may": 589, + "ity": 590, + " loo": 591, + " other": 592, + "his": 593, + " Pro": 594, + " will": 595, + " It": 596, + "ort": 597, + " should": 598, + "very": 599, + "we": 600, + " pl": 601, + "ash": 602, + ".\"": 603, + " app": 604, + " day": 605, + "urn": 606, + "po": 607, + " her": 608, + " ": 609, + "not": 610, + "ck": 611, + " un": 612, + "hi": 613, + "ving": 614, + " old": 615, + " time": 616, + "\"T": 617, + " way": 618, + "able": 619, + "?\"\n": 620, + " Clown": 621, + " only": 622, + "ub": 623, + "ach": 624, + " off": 625, + " than": 626, + "ally": 627, + " their": 628, + "be": 629, + "king": 630, + "other": 631, + "ary": 632, + "ans": 633, + "ated": 634, + "self": 635, + " going": 636, + "uch": 637, + "oll": 638, + " back": 639, + "iyo": 640, + "-t": 641, + "ance": 642, + "ade": 643, + " Project": 644, + "sp": 645, + " two": 646, + " thought": 647, + "so": 648, + " right": 649, + " head": 650, + "ved": 651, + " D": 652, + " pre": 653, + " see": 654, + " us": 655, + " students": 656, + "cip": 657, + " don": 658, + " night": 659, + "incip": 660, + " Kiyo": 661, + "pl": 662, + "ared": 663, + " Gutenberg": 664, + " co": 665, + " how": 666, + "omet": 667, + "ff": 668, + "\"I": 669, + ",--": 670, + " asked": 671, + "incipal": 672, + "ever": 673, + " ac": 674, + " F": 675, + " make": 676, + "itt": 677, + " might": 678, + "ge": 679, + "led": 680, + " after": 681, + "ign": 682, + " gr": 683, + " made": 684, + "dd": 685, + " know": 686, + " come": 687, + " br": 688, + "thing": 689, + " But": 690, + " mat": 691, + " On": 692, + "ory": 693, + "cl": 694, + " E": 695, + "ble": 696, + "og": 697, + " your": 698, + "ull": 699, + " work": 700, + "ear": 701, + " three": 702, + "ied": 703, + "but": 704, + "The": 705, + "pe": 706, + "ace": 707, + " start": 708, + "ick": 709, + " over": 710, + "our": 711, + " much": 712, + " want": 713, + "imp": 714, + " part": 715, + "ho": 716, + "ink": 717, + "ence": 718, + " down": 719, + " even": 720, + " principal": 721, + "ling": 722, + "ount": 723, + "ause": 724, + " cl": 725, + " bl": 726, + "-tm": 727, + "omething": 728, + " into": 729, + "orm": 730, + "okyo": 731, + " dis": 732, + " fe": 733, + " face": 734, + "......": 735, + "ress": 736, + "ment": 737, + "ire": 738, + " ar": 739, + "ty": 740, + " mo": 741, + "reat": 742, + " fir": 743, + "per": 744, + " our": 745, + "co": 746, + " then": 747, + " told": 748, + "ings": 749, + " take": 750, + " beg": 751, + "ner": 752, + "ition": 753, + "ose": 754, + " own": 755, + " again": 756, + " seem": 757, + "ise": 758, + " wat": 759, + "\"W": 760, + " far": 761, + "aking": 762, + "fore": 763, + "ady": 764, + "-s": 765, + "less": 766, + " ret": 767, + " sha": 768, + " came": 769, + "ger": 770, + " good": 771, + "ather": 772, + "ark": 773, + "row": 774, + " ke": 775, + "'m": 776, + " has": 777, + "ath": 778, + "pped": 779, + " went": 780, + " tell": 781, + "quash": 782, + " en": 783, + " first": 784, + " hot": 785, + "iz": 786, + " away": 787, + " something": 788, + " rem": 789, + " town": 790, + " sm": 791, + " This": 792, + " better": 793, + " Then": 794, + "was": 795, + "of": 796, + "bard": 797, + " L": 798, + "li": 799, + "fe": 800, + " Tokyo": 801, + " long": 802, + "ily": 803, + " sure": 804, + " looked": 805, + "ubbard": 806, + "ction": 807, + "ord": 808, + " many": 809, + "ious": 810, + " too": 811, + " here": 812, + "os": 813, + " under": 814, + "ase": 815, + "ng": 816, + "ped": 817, + "od": 818, + "me": 819, + " just": 820, + " now": 821, + "ince": 822, + " heard": 823, + " kind": 824, + " They": 825, + " before": 826, + "hy": 827, + " In": 828, + " ent": 829, + " board": 830, + "!\"": 831, + "ward": 832, + " being": 833, + " well": 834, + "erm": 835, + "ried": 836, + " wrong": 837, + "aid": 838, + "xt": 839, + " return": 840, + "ited": 841, + " yen": 842, + " matter": 843, + " call": 844, + " tal": 845, + " You": 846, + "ced": 847, + "ised": 848, + " cha": 849, + "ons": 850, + " same": 851, + " once": 852, + "day": 853, + "ft": 854, + " sw": 855, + " because": 856, + " think": 857, + " where": 858, + " No": 859, + " Hubbard": 860, + " Squash": 861, + " cop": 862, + "with": 863, + "ered": 864, + "ollow": 865, + " place": 866, + "idd": 867, + "cess": 868, + " show": 869, + "isha": 870, + " ra": 871, + " letter": 872, + "ne": 873, + "ves": 874, + "ating": 875, + "rang": 876, + " aff": 877, + " hand": 878, + " sc": 879, + " pers": 880, + "int": 881, + "pr": 882, + "side": 883, + "fter": 884, + " saying": 885, + " lau": 886, + "that": 887, + " without": 888, + "ron": 889, + "air": 890, + "lect": 891, + " What": 892, + "elt": 893, + " while": 894, + "oga": 895, + "aper": 896, + " pe": 897, + "oy": 898, + " sat": 899, + "ies": 900, + " add": 901, + " days": 902, + " spe": 903, + " ho": 904, + " ans": 905, + " har": 906, + " When": 907, + " anything": 908, + "pen": 909, + "]\n": 910, + "tain": 911, + " must": 912, + " new": 913, + "lic": 914, + " vo": 915, + "hile": 916, + "get": 917, + " As": 918, + " very": 919, + "'re": 920, + " every": 921, + "ave": 922, + "?\"": 923, + "adger": 924, + " Koga": 925, + " Mr": 926, + "rough": 927, + "ult": 928, + " follow": 929, + "ting": 930, + "ife": 931, + "iddle": 932, + "ful": 933, + "ank": 934, + " So": 935, + " seemed": 936, + " And": 937, + "ix": 938, + " set": 939, + " care": 940, + " res": 941, + " never": 942, + " found": 943, + " lo": 944, + "cid": 945, + "ined": 946, + " class": 947, + " myself": 948, + "aw": 949, + " wom": 950, + "ations": 951, + " left": 952, + " We": 953, + " teachers": 954, + "\"Y": 955, + "na": 956, + "ont": 957, + " des": 958, + " those": 959, + "ired": 960, + " sen": 961, + "ying": 962, + " these": 963, + "az": 964, + " There": 965, + "cept": 966, + " dang": 967, + " U": 968, + "\"H": 969, + "bod": 970, + "body": 971, + " having": 972, + "alary": 973, + " watch": 974, + " give": 975, + "age": 976, + " its": 977, + " appe": 978, + "ue": 979, + " count": 980, + " hard": 981, + " bel": 982, + "ott": 983, + " dist": 984, + "\"S": 985, + " Mad": 986, + "-n": 987, + "ribut": 988, + "ged": 989, + " att": 990, + "fere": 991, + "ither": 992, + " upon": 993, + " tem": 994, + " person": 995, + "ning": 996, + " che": 997, + "arly": 998, + "oney": 999, + " soon": 1000, + "ement": 1001, + " (": 1002, + " trans": 1003, + " exp": 1004, + " ser": 1005, + " reg": 1006, + "ason": 1007, + " saw": 1008, + " next": 1009, + "oot": 1010, + " half": 1011, + " took": 1012, + " bad": 1013, + " hour": 1014, + " salary": 1015, + " began": 1016, + "right": 1017, + "onna": 1018, + "-san": 1019, + " works": 1020, + " J": 1021, + "form": 1022, + "ical": 1023, + " tra": 1024, + "man": 1025, + " nothing": 1026, + " still": 1027, + "ears": 1028, + " supp": 1029, + " turn": 1030, + " felt": 1031, + " woman": 1032, + " started": 1033, + "ouble": 1034, + "ura": 1035, + "ishing": 1036, + ":\n": 1037, + "lectron": 1038, + "lectronic": 1039, + "ook": 1040, + " copy": 1041, + " full": 1042, + "cond": 1043, + "mat": 1044, + " middle": 1045, + " look": 1046, + " comm": 1047, + "wered": 1048, + " became": 1049, + " fellows": 1050, + "would": 1051, + " got": 1052, + " gl": 1053, + " gu": 1054, + " keep": 1055, + " ge": 1056, + " Madonna": 1057, + "iter": 1058, + "ished": 1059, + " underst": 1060, + " stra": 1061, + "sid": 1062, + " country": 1063, + "ople": 1064, + " prov": 1065, + " put": 1066, + "no": 1067, + "'ll": 1068, + " sle": 1069, + "range": 1070, + " She": 1071, + "pos": 1072, + " mind": 1073, + " pass": 1074, + " through": 1075, + " quite": 1076, + " ind": 1077, + " boarding": 1078, + "teacher": 1079, + "ple": 1080, + "Porcupine": 1081, + " ple": 1082, + " geisha": 1083, + " ": 1084, + "ost": 1085, + "ense": 1086, + "No": 1087, + "ible": 1088, + " read": 1089, + " red": 1090, + "ention": 1091, + "ened": 1092, + "!\"\n": 1093, + " ref": 1094, + " ad": 1095, + " fl": 1096, + " stay": 1097, + "up": 1098, + " round": 1099, + " cle": 1100, + " open": 1101, + " ob": 1102, + "tend": 1103, + " find": 1104, + " per": 1105, + " called": 1106, + " sur": 1107, + "rew": 1108, + " paper": 1109, + " Badger": 1110, + " meet": 1111, + "iss": 1112, + "\"That": 1113, + "erms": 1114, + "TE": 1115, + "itten": 1116, + "ably": 1117, + "ness": 1118, + " cannot": 1119, + " simp": 1120, + "con": 1121, + " reason": 1122, + "you": 1123, + " home": 1124, + "by": 1125, + " fight": 1126, + "ittle": 1127, + " things": 1128, + " eas": 1129, + " imp": 1130, + "ressed": 1131, + " mean": 1132, + " appeared": 1133, + " nat": 1134, + " hel": 1135, + "ret": 1136, + "aken": 1137, + " straight": 1138, + " affair": 1139, + "iting": 1140, + " ed": 1141, + " since": 1142, + "log": 1143, + " pay": 1144, + " front": 1145, + "my": 1146, + " voice": 1147, + "ready": 1148, + " fool": 1149, + "oundation": 1150, + " electronic": 1151, + " terms": 1152, + " mar": 1153, + "apan": 1154, + "any": 1155, + " resp": 1156, + " end": 1157, + "app": 1158, + "what": 1159, + "str": 1160, + "rap": 1161, + "ial": 1162, + "icul": 1163, + " acc": 1164, + "oth": 1165, + " second": 1166, + " flo": 1167, + " six": 1168, + " feet": 1169, + "br": 1170, + "iet": 1171, + " little": 1172, + "les": 1173, + " money": 1174, + " decl": 1175, + " ey": 1176, + " comp": 1177, + "aring": 1178, + " agre": 1179, + "where": 1180, + " St": 1181, + " stre": 1182, + "ex": 1183, + "ract": 1184, + " int": 1185, + " dire": 1186, + " become": 1187, + " hon": 1188, + " consid": 1189, + "ertain": 1190, + "now": 1191, + " sl": 1192, + "itor": 1193, + "gg": 1194, + " jum": 1195, + " bu": 1196, + " thing": 1197, + " answered": 1198, + "oes": 1199, + "ya": 1200, + " That": 1201, + "ize": 1202, + "ond": 1203, + "act": 1204, + " eff": 1205, + " bang": 1206, + "about": 1207, + " bed": 1208, + "orrow": 1209, + "ung": 1210, + " To": 1211, + " kept": 1212, + " wal": 1213, + " bath": 1214, + " dra": 1215, + "\"A": 1216, + "rings": 1217, + "hopp": 1218, + " resign": 1219, + " din": 1220, + " lady": 1221, + ".E": 1222, + " use": 1223, + "lish": 1224, + "ors": 1225, + " written": 1226, + "ene": 1227, + "iv": 1228, + " dif": 1229, + " ste": 1230, + " story": 1231, + "com": 1232, + "res": 1233, + "ently": 1234, + " fact": 1235, + "hes": 1236, + "ways": 1237, + " why": 1238, + " though": 1239, + " str": 1240, + "onder": 1241, + "head": 1242, + " cour": 1243, + " mon": 1244, + " sk": 1245, + " belie": 1246, + " let": 1247, + "fer": 1248, + " requ": 1249, + " line": 1250, + "room": 1251, + "-day": 1252, + " done": 1253, + " does": 1254, + " One": 1255, + " dango": 1256, + "asshopp": 1257, + " consider": 1258, + " dinner": 1259, + " Foundation": 1260, + "**": 1261, + "empt": 1262, + "ese": 1263, + " word": 1264, + "rest": 1265, + " enough": 1266, + " great": 1267, + " name": 1268, + " pub": 1269, + " manner": 1270, + "wer": 1271, + "ict": 1272, + "iness": 1273, + " himself": 1274, + " people": 1275, + "ew": 1276, + " cor": 1277, + "estion": 1278, + " big": 1279, + "ee": 1280, + " ri": 1281, + "ides": 1282, + " brother": 1283, + " heart": 1284, + "ected": 1285, + "eed": 1286, + " others": 1287, + "sol": 1288, + "ted": 1289, + " eyes": 1290, + " trouble": 1291, + " teach": 1292, + " boat": 1293, + " four": 1294, + " already": 1295, + "rom": 1296, + "ghed": 1297, + " squ": 1298, + " pol": 1299, + "ces": 1300, + " Hott": 1301, + " leave": 1302, + " distribut": 1303, + "aster": 1304, + "CH": 1305, + "uc": 1306, + " im": 1307, + " however": 1308, + "there": 1309, + "apanese": 1310, + " last": 1311, + " cr": 1312, + "ility": 1313, + " simple": 1314, + " life": 1315, + "-c": 1316, + " regard": 1317, + " fin": 1318, + "ual": 1319, + " means": 1320, + " stand": 1321, + "atch": 1322, + " short": 1323, + "ned": 1324, + " seen": 1325, + " happ": 1326, + "-k": 1327, + " against": 1328, + "him": 1329, + "amed": 1330, + " stood": 1331, + " gra": 1332, + " mother": 1333, + " fish": 1334, + " water": 1335, + "ail": 1336, + "cei": 1337, + " rather": 1338, + " ins": 1339, + " feel": 1340, + " also": 1341, + " ord": 1342, + " coming": 1343, + "ics": 1344, + " either": 1345, + "nce": 1346, + " '": 1347, + " kid": 1348, + " laughed": 1349, + "like": 1350, + " Ar": 1351, + "gr": 1352, + " Hotta": 1353, + " talk": 1354, + "gether": 1355, + " Sir": 1356, + " pun": 1357, + "Pro": 1358, + "ats": 1359, + "most": 1360, + " rep": 1361, + " gi": 1362, + "isf": 1363, + "bably": 1364, + "akes": 1365, + " Not": 1366, + "ny": 1367, + " appear": 1368, + "mp": 1369, + "cha": 1370, + " act": 1371, + "bed": 1372, + "ief": 1373, + "uff": 1374, + " apo": 1375, + " met": 1376, + " returned": 1377, + " sound": 1378, + "usiness": 1379, + " laugh": 1380, + " clear": 1381, + " need": 1382, + "fess": 1383, + "ested": 1384, + " inv": 1385, + " accept": 1386, + "under": 1387, + ";\n": 1388, + " surpr": 1389, + "de": 1390, + " train": 1391, + " hotel": 1392, + " sleep": 1393, + " dr": 1394, + " hold": 1395, + "lock": 1396, + "pura": 1397, + " springs": 1398, + " ......": 1399, + " agreement": 1400, + " Dar": 1401, + " rest": 1402, + "clud": 1403, + "ator": 1404, + "av": 1405, + " orig": 1406, + " origin": 1407, + " el": 1408, + " nor": 1409, + " pres": 1410, + " understand": 1411, + " taken": 1412, + " light": 1413, + "ener": 1414, + "some": 1415, + " brought": 1416, + "raph": 1417, + " most": 1418, + "oke": 1419, + "-w": 1420, + " unt": 1421, + " father": 1422, + " used": 1423, + " eat": 1424, + " years": 1425, + " While": 1426, + " chan": 1427, + " sudd": 1428, + " sudden": 1429, + " apolog": 1430, + " sett": 1431, + " thin": 1432, + " My": 1433, + " ten": 1434, + "imes": 1435, + "for": 1436, + "oud": 1437, + "When": 1438, + " det": 1439, + " live": 1440, + " oc": 1441, + " five": 1442, + " cont": 1443, + " help": 1444, + " wa": 1445, + " passed": 1446, + " run": 1447, + " making": 1448, + " strange": 1449, + " taking": 1450, + " each": 1451, + "\"You": 1452, + " another": 1453, + "\"Say": 1454, + "\"The": 1455, + "ates": 1456, + " pleas": 1457, + "asshoppers": 1458, + " mom": 1459, + " moment": 1460, + "entle": 1461, + "nglish": 1462, + "CHA": 1463, + " original": 1464, + "ions": 1465, + "uring": 1466, + " public": 1467, + "uct": 1468, + "uck": 1469, + " question": 1470, + "ai": 1471, + "cy": 1472, + "ek": 1473, + " floor": 1474, + " car": 1475, + "ouse": 1476, + " side": 1477, + "-ya": 1478, + " certain": 1479, + "hys": 1480, + "-d": 1481, + "igh": 1482, + "agin": 1483, + "weet": 1484, + " poor": 1485, + " decid": 1486, + "ually": 1487, + " business": 1488, + "pro": 1489, + "plain": 1490, + " stop": 1491, + "!\n": 1492, + " How": 1493, + "\"What": 1494, + "can": 1495, + " Un": 1496, + "ps": 1497, + "und": 1498, + "-night": 1499, + " meeting": 1500, + "edo": 1501, + " raise": 1502, + "Gutenberg": 1503, + " Darling": 1504, + "ume": 1505, + " English": 1506, + "TER": 1507, + "ading": 1508, + " transl": 1509, + " able": 1510, + "ssible": 1511, + " satisf": 1512, + " wanted": 1513, + " sub": 1514, + " case": 1515, + "ific": 1516, + "iterary": 1517, + " maid": 1518, + " inc": 1519, + " pos": 1520, + " position": 1521, + " pat": 1522, + "ured": 1523, + "orry": 1524, + " account": 1525, + " both": 1526, + " frie": 1527, + " friend": 1528, + "this": 1529, + " always": 1530, + " particul": 1531, + "What": 1532, + " small": 1533, + "enty": 1534, + "ushed": 1535, + " mis": 1536, + "ully": 1537, + " recei": 1538, + "You": 1539, + " yet": 1540, + " gave": 1541, + "But": 1542, + "had": 1543, + " answer": 1544, + " abs": 1545, + "ile": 1546, + "cket": 1547, + " nood": 1548, + " course": 1549, + " form": 1550, + " everything": 1551, + "ection": 1552, + "If": 1553, + "part": 1554, + " sing": 1555, + " sit": 1556, + " pur": 1557, + "ip": 1558, + " fishing": 1559, + " eh": 1560, + " par": 1561, + " together": 1562, + "He": 1563, + " whe": 1564, + " whether": 1565, + " bra": 1566, + "\"Yes": 1567, + " punish": 1568, + "Shirt": 1569, + " Yedo": 1570, + " farew": 1571, + " farewell": 1572, + " dance": 1573, + " less": 1574, + "ural": 1575, + " def": 1576, + " attempt": 1577, + "ween": 1578, + " sign": 1579, + " sy": 1580, + "ferent": 1581, + " least": 1582, + "ser": 1583, + "ob": 1584, + "nding": 1585, + " sorry": 1586, + " jumped": 1587, + " jan": 1588, + " janitor": 1589, + "ized": 1590, + " toward": 1591, + " mor": 1592, + "aving": 1593, + " bit": 1594, + "\"This": 1595, + " remark": 1596, + " fut": 1597, + " wonder": 1598, + " fun": 1599, + "Then": 1600, + " dec": 1601, + " whom": 1602, + " didn": 1603, + " rec": 1604, + "bec": 1605, + "\"If": 1606, + " knew": 1607, + "after": 1608, + " thus": 1609, + " isn": 1610, + " sight": 1611, + "med": 1612, + "[F": 1613, + "uss": 1614, + "cident": 1615, + "them": 1616, + " fif": 1617, + " draw": 1618, + " hear": 1619, + " writing": 1620, + " getting": 1621, + "sh": 1622, + "ference": 1623, + " raised": 1624, + "they": 1625, + "ax": 1626, + " fine": 1627, + "sel": 1628, + " Nobe": 1629, + " Nobeok": 1630, + " Nobeoka": 1631, + "ormal": 1632, + " eB": 1633, + "icense": 1634, + "00": 1635, + " best": 1636, + "wor": 1637, + "fic": 1638, + "terest": 1639, + " remar": 1640, + "bl": 1641, + "arted": 1642, + " dark": 1643, + " young": 1644, + "ush": 1645, + " bet": 1646, + "outh": 1647, + "house": 1648, + "aught": 1649, + " phys": 1650, + " strong": 1651, + " fur": 1652, + " roll": 1653, + "cove": 1654, + "chief": 1655, + "awa": 1656, + " followed": 1657, + " fond": 1658, + " future": 1659, + "ird": 1660, + "fully": 1661, + " effort": 1662, + "After": 1663, + "oward": 1664, + " really": 1665, + " among": 1666, + " around": 1667, + " compl": 1668, + " gaz": 1669, + " bow": 1670, + "ater": 1671, + " insist": 1672, + " turned": 1673, + "hel": 1674, + "rem": 1675, + " hours": 1676, + " decided": 1677, + "ys": 1678, + " month": 1679, + "-a": 1680, + " adv": 1681, + " believe": 1682, + " teaching": 1683, + " easy": 1684, + " direction": 1685, + "ooked": 1686, + " war": 1687, + " unless": 1688, + "have": 1689, + " square": 1690, + "vil": 1691, + " quiet": 1692, + " hung": 1693, + " goes": 1694, + " paid": 1695, + " shall": 1696, + "\"No": 1697, + " punishment": 1698, + "pose": 1699, + " sweet": 1700, + "'ve": 1701, + "\"Well": 1702, + " gentle": 1703, + " normal": 1704, + "agraph": 1705, + "chive": 1706, + "chan": 1707, + " includ": 1708, + "ww": 1709, + "org": 1710, + "tem": 1711, + "AR": 1712, + " TH": 1713, + " equ": 1714, + " tone": 1715, + " possible": 1716, + " becom": 1717, + " Japanese": 1718, + "vers": 1719, + " following": 1720, + " pain": 1721, + " whole": 1722, + "wr": 1723, + " serious": 1724, + " nar": 1725, + " tired": 1726, + "In": 1727, + " play": 1728, + " prom": 1729, + " game": 1730, + " Some": 1731, + " happened": 1732, + " cut": 1733, + " twenty": 1734, + " door": 1735, + " morning": 1736, + "hind": 1737, + " bre": 1738, + " inside": 1739, + "ove": 1740, + "alth": 1741, + "uk": 1742, + "arge": 1743, + "amb": 1744, + " dam": 1745, + " worry": 1746, + "ative": 1747, + " expected": 1748, + " fam": 1749, + " pra": 1750, + " pocket": 1751, + "ooks": 1752, + "ched": 1753, + " sil": 1754, + "ol": 1755, + " fav": 1756, + " else": 1757, + " high": 1758, + " real": 1759, + " along": 1760, + " med": 1761, + "hik": 1762, + "hemat": 1763, + "hematics": 1764, + " list": 1765, + " sick": 1766, + "oint": 1767, + "[Foot": 1768, + "[Footnot": 1769, + "[Footnote": 1770, + ".]\n": 1771, + "night": 1772, + "ses": 1773, + "ior": 1774, + " says": 1775, + " mouth": 1776, + "how": 1777, + "ming": 1778, + " clo": 1779, + " cur": 1780, + "ging": 1781, + " suddenly": 1782, + "-ah": 1783, + "amp": 1784, + " black": 1785, + "ross": 1786, + " fac": 1787, + "selves": 1788, + "iew": 1789, + "ission": 1790, + " copyright": 1791, + " paragraph": 1792, + " Archive": 1793, + " donations": 1794, + "Project": 1795, + " cost": 1796, + ".org": 1797, + "LI": 1798, + "uced": 1799, + " suc": 1800, + "yle": 1801, + " force": 1802, + "joy": 1803, + "ouch": 1804, + "tr": 1805, + "It": 1806, + " trad": 1807, + " present": 1808, + " ext": 1809, + "ased": 1810, + "redit": 1811, + " fault": 1812, + "ib": 1813, + "-m": 1814, + "urd": 1815, + " tried": 1816, + "time": 1817, + " pret": 1818, + " spee": 1819, + "ower": 1820, + " words": 1821, + "CHAP": 1822, + "CHAPTER": 1823, + "school": 1824, + " ask": 1825, + " doing": 1826, + "ately": 1827, + " until": 1828, + "bout": 1829, + " tree": 1830, + "call": 1831, + "amash": 1832, + "amashir": 1833, + "amashiro": 1834, + "ste": 1835, + " behind": 1836, + "old": 1837, + " wall": 1838, + "itory": 1839, + " rolled": 1840, + " move": 1841, + " apologize": 1842, + " large": 1843, + "amboo": 1844, + "su": 1845, + " settled": 1846, + "\"He": 1847, + "wo": 1848, + " thinking": 1849, + "used": 1850, + "ified": 1851, + " almost": 1852, + " tre": 1853, + " treat": 1854, + " noodle": 1855, + " note": 1856, + " All": 1857, + " beat": 1858, + " object": 1859, + " seems": 1860, + " ide": 1861, + "Yes": 1862, + "ows": 1863, + " remain": 1864, + " begin": 1865, + "ught": 1866, + "ments": 1867, + " alone": 1868, + "spect": 1869, + " mathematics": 1870, + " rough": 1871, + " outside": 1872, + " comes": 1873, + "back": 1874, + " wind": 1875, + "sed": 1876, + " wouldn": 1877, + "eer": 1878, + "inut": 1879, + "from": 1880, + " repl": 1881, + " narrow": 1882, + " incident": 1883, + " air": 1884, + " sea": 1885, + "ts": 1886, + " surprised": 1887, + " tea": 1888, + "Red": 1889, + " talking": 1890, + " boss": 1891, + "que": 1892, + " pict": 1893, + "irty": 1894, + " ce": 1895, + " lim": 1896, + " Why": 1897, + " point": 1898, + " law": 1899, + "ciated": 1900, + " moon": 1901, + "ircu": 1902, + "got": 1903, + " Is": 1904, + " hands": 1905, + " honor": 1906, + "aut": 1907, + "rge": 1908, + " state": 1909, + " Literary": 1910, + ".F": 1911, + "This": 1912, + "line": 1913, + ".g": 1914, + ".gutenberg": 1915, + " OF": 1916, + "EN": 1917, + "racter": 1918, + " bene": 1919, + " Even": 1920, + "oub": 1921, + " makes": 1922, + " interest": 1923, + "ope": 1924, + "ms": 1925, + " respons": 1926, + " fore": 1927, + " somewhat": 1928, + " honest": 1929, + "ock": 1930, + "irit": 1931, + " held": 1932, + " added": 1933, + "fu": 1934, + "aded": 1935, + "als": 1936, + "att": 1937, + "tern": 1938, + " personal": 1939, + " ass": 1940, + " With": 1941, + "tic": 1942, + "Tokyo": 1943, + " shout": 1944, + " pretty": 1945, + "umb": 1946, + " early": 1947, + "opped": 1948, + " further": 1949, + " fre": 1950, + "esides": 1951, + " bamboo": 1952, + " ir": 1953, + "more": 1954, + " living": 1955, + " received": 1956, + " lived": 1957, + " meant": 1958, + " coward": 1959, + "position": 1960, + " loc": 1961, + "iled": 1962, + " tender": 1963, + " ch": 1964, + " After": 1965, + "cer": 1966, + " favor": 1967, + "who": 1968, + " liked": 1969, + "rance": 1970, + " pri": 1971, + "kisha": 1972, + " study": 1973, + " order": 1974, + " afterward": 1975, + " greatly": 1976, + " unable": 1977, + "go": 1978, + " wait": 1979, + "eping": 1980, + "iding": 1981, + " forty": 1982, + " sky": 1983, + " office": 1984, + "will": 1985, + "\"D": 1986, + "wel": 1987, + " station": 1988, + "bo": 1989, + "hot": 1990, + "such": 1991, + " loud": 1992, + " aw": 1993, + "land": 1994, + "?\n": 1995, + " respect": 1996, + "ances": 1997, + "<|image|>": 1998, + "<|begin_of_image|>": 1999, + "<|end_of_image|>": 2000, + "<|pad|>": 2001 + }, + "merges": [ + ] + } +} diff --git a/torchtitan/experiments/vlm/assets/tokenizer/tokenizer_config.json b/torchtitan/experiments/vlm/assets/tokenizer/tokenizer_config.json new file mode 100644 index 000000000..c2aad203a --- /dev/null +++ b/torchtitan/experiments/vlm/assets/tokenizer/tokenizer_config.json @@ -0,0 +1,65 @@ +{ + "added_tokens_decoder": { + "128000": { + "content": "<|begin_of_text|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "128001": { + "content": "<|end_of_text|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1998": { + "content": "<|image|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1999": { + "content": "<|begin_of_image|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2000": { + "content": "<|end_of_image|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2001": { + "content": "<|pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "<|begin_of_text|>", + "clean_up_tokenization_spaces": true, + "eos_token": "<|end_of_text|>", + "img_token": "<|image|>", + "boi_token": "<|begin_of_image|>", + "eoi_token": "<|end_of_image|>", + "pad_token": "<|pad|>", + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 131072, + "tokenizer_class": "PreTrainedTokenizerFast" +} diff --git a/torchtitan/experiments/vlm/datasets/image_utils.py b/torchtitan/experiments/vlm/datasets/image_utils.py new file mode 100644 index 000000000..7c70e3d87 --- /dev/null +++ b/torchtitan/experiments/vlm/datasets/image_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Utility functions for image processing in multimodal datasets.""" + +import math +from io import BytesIO + +import einops as E +import numpy as np +import requests +import torch + +from PIL import Image + +from torchtitan.tools.logging import logger + + +def process_image( + image: str | bytes | Image.Image, + patch_size: int = 16, + merge_size: int = 1, + max_patch_per_image: int = 256, + min_patch_per_image: int = 1, +) -> torch.Tensor | None: + """Process a single image into normalized tensor format. + + Args: + image: PIL Image, bytes, or URL string + patch_size: Size of each patch + merge_size: Spatial Merge size factor + max_patch_per_image: Maximum patches allowed per image + min_dimension: Minimum dimension for width/height + + Returns: + Tensor of shape (1, H, W, 3) or None if processing fails + + Note: + - Resizes image while maintaining aspect ratio + - Normalizes using CLIP mean/std values + - Returns None if any processing step fails + """ + try: + # Convert various input formats to PIL Image + if isinstance(image, str) and image.startswith("http"): + response = requests.get(image, timeout=10) + image = Image.open(BytesIO(response.content)) + elif isinstance(image, bytes): + image = Image.open(BytesIO(image)) + elif isinstance(image, str): + image = Image.open(image) + + if image.mode != "RGB": + image = image.convert("RGB") + + # Resize maintaining aspect ratio + image = resize_image_by_patch_count( + image, + max_patch_per_image=max_patch_per_image, + patch_size=patch_size, + merge_size=merge_size, + min_patch_per_image=min_patch_per_image, + ) + + # Convert to numpy and normalize + img_array = np.array(image) + img_array = img_array / 255.0 + + # CLIP normalization + mean = np.array([0.48145466, 0.4578275, 0.40821073]) + std = np.array([0.26862954, 0.26130258, 0.27577711]) + img_array = (img_array - mean) / std + + # Convert to tensor (1, H, W, 3) with dummy temporal dim + return torch.from_numpy(img_array).float().unsqueeze(0) + + except Exception as e: + logger.warning(f"Error processing image: {e}") + return None + + +def smart_resize( + height: int, + width: int, + factor: int, # should be equal patch_size * merge_size + max_patch_per_image: int, + min_patch_per_image: int = 1, +): + """Calculate dimensions that maintain aspect ratio and satisfy constraints.""" + if height < factor or width < factor: + raise ValueError( + f"height:{height} or width:{width} must be larger than factor:{factor}" + ) + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + + # Calculate patch count from adjusted dimensions + current_patches = (h_bar * w_bar) // (factor * factor) + + if current_patches > max_patch_per_image: + # Scale down to fit within max patch limit + max_area = max_patch_per_image * (factor * factor) + beta = math.sqrt((h_bar * w_bar) / max_area) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif current_patches < min_patch_per_image: + beta = math.sqrt(min_patch_per_image / current_patches) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + return h_bar, w_bar + + +def resize_image_by_patch_count( + image: Image.Image, + max_patch_per_image: int, + patch_size: int, + merge_size: int, + min_patch_per_image: int = 1, +) -> Image.Image: + """Resize image while maintaining aspect ratio and ensuring patch count is within [min_patch_per_image, max_patch_per_image].""" + original_width, original_height = image.size + factor = patch_size * merge_size + + # Calculate current number of patches + current_patches = (original_height * original_width) // (factor * factor) + + # If patches < min_patch_per_image, scale up proportionally + if current_patches < min_patch_per_image: + if current_patches == 0: + # Special case: image too small to produce any patches + # Scale to minimum viable size (at least factor x factor) + scale_factor = max(factor / original_width, factor / original_height) + else: + scale_factor = math.sqrt(min_patch_per_image / current_patches) + + new_width = int(original_width * scale_factor) + new_height = int(original_height * scale_factor) + + resized_height, resized_width = smart_resize( + new_height, + new_width, + factor, + max_patch_per_image, + ) + return image.resize((resized_width, resized_height)) + + # If patches are within [min, max] range, just use smart_resize + elif current_patches <= max_patch_per_image: + resized_height, resized_width = smart_resize( + original_height, original_width, factor, max_patch_per_image + ) + return image.resize((resized_width, resized_height)) + + # If patches > max_patch_per_image, scale down proportionally + else: + scale_factor = math.sqrt(max_patch_per_image / current_patches) + new_width = int(original_width * scale_factor) + new_height = int(original_height * scale_factor) + + resized_height, resized_width = smart_resize( + new_height, new_width, factor, max_patch_per_image + ) + return image.resize((resized_width, resized_height)) + + +def calculate_image_tokens( + image: Image.Image | torch.Tensor, + patch_size: int, + spatial_merge_size: int, +) -> tuple[int, int, int]: + """Calculate number of tokens needed for an image.""" + if isinstance(image, torch.Tensor): + height, width = image.shape[1:3] + else: + width, height = image.size + + tokens_per_row = int(width / (patch_size * spatial_merge_size)) + num_rows = int(height / (patch_size * spatial_merge_size)) + total_tokens = tokens_per_row * num_rows + + return total_tokens, tokens_per_row, num_rows + + +def convert_to_patches( + pixel_values: torch.Tensor, + patch_size: int, + temporal_patch_size: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """Convert single image tensor to patches and generate coordinate grids. + + Args: + pixel_values: Tensor of shape (T, H, W, C) + patch_size: Spatial patch size (height and width) + temporal_patch_size: Temporal patch size (default=1 for no temporal patching) + + Returns: + patches: Tensor of shape (L, D) where: + L = (T//temporal_patch_size) * (H//patch_size) * (W//patch_size) + D = temporal_patch_size * patch_size * patch_size * C + grid: Tensor of shape (L, 3) containing (t, h, w) coordinates + + Example: + >>> x = torch.randn(4, 224, 224, 3) # Single image with 4 frames + >>> patches, grid = convert_to_patches(x, patch_size=14, temporal_patch_size=2) + >>> print(patches.shape) # (512, 1176) # 512 patches, each 1176-dim + >>> print(grid.shape) # (512, 3) # (t,h,w) coordinates + """ + T, H, W, C = pixel_values.shape + ps = patch_size + ts = temporal_patch_size + device = pixel_values.device + + # Ensure dimensions are divisible + if T % ts != 0: + raise ValueError( + f"Temporal dimension {T} must be divisible by temporal_patch_size {ts}" + ) + if H % ps != 0 or W % ps != 0: + raise ValueError( + f"Spatial dimensions {H},{W} must be divisible by patch_size {ps}" + ) + + patches = E.rearrange( + pixel_values, + "(t pt) (h ph) (w pw) c -> (t h w) (pt ph pw c)", + pt=ts, + ph=ps, + pw=ps, + ) + + # Generate coordinate grid + coords = torch.meshgrid( + torch.arange(T // ts, device=device), + torch.arange(H // ps, device=device), + torch.arange(W // ps, device=device), + indexing="ij", + ) + grid = E.rearrange(torch.stack(coords), "coords t h w -> (t h w) coords") # (L, 3) + + return patches, grid + + +def pad_patches( + patches: torch.Tensor, # Shape L,D + grids: torch.Tensor, # Shape L,3(thw) + max_patches: int, +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Pad or truncate patches and grids to max_patches length for single image.""" + L, D = patches.shape + + if L == max_patches: + return patches, grids + elif L < max_patches: + # Pad + pad_len = max_patches - L + zero_patches = torch.zeros(pad_len, D, device=patches.device) + invalid_grids = torch.full((pad_len, 3), -1, device=grids.device) + return ( + torch.cat([patches, zero_patches], 0), + torch.cat([grids, invalid_grids], 0), + ) + else: + # Truncate + logger.error( + f"Truncating Image Patches from {L} to {max_patches} should not happen." + ) + return None, None + + +def pad_empty_images_to_target_batch_size( + patches: torch.Tensor, + grids: torch.Tensor, + max_images: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pad vision encoder batch with blank images if needed.""" + N, L, D = patches.shape + if N >= max_images: + return patches, grids + + blank_count = max_images - N + blank_patches = torch.zeros(blank_count, L, D, device=patches.device) + blank_grids = torch.full((blank_count, L, 3), -1, device=grids.device) + return ( + torch.cat([patches, blank_patches], dim=0), + torch.cat([grids, blank_grids], dim=0), + ) diff --git a/torchtitan/experiments/vlm/datasets/mm_collator_nld.py b/torchtitan/experiments/vlm/datasets/mm_collator_nld.py new file mode 100644 index 000000000..777d726e5 --- /dev/null +++ b/torchtitan/experiments/vlm/datasets/mm_collator_nld.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Any + +import torch +from torch.nn.utils.rnn import pad_sequence + +from torchtitan.tools.logging import logger + +from .image_utils import ( + convert_to_patches, + pad_empty_images_to_target_batch_size, + pad_patches, +) +from .text_utils import pad_input_ids_and_labels_to_target_batch_size, pad_text_batch + + +@dataclass +class MultiModalCollatorNLD: + """Multimodal collator that works with image patches in NLD format. + N: Number of images (vision encoder's batch size) + L: Length of patches (vision encoder's sequence length) + D: Dimension of a patch (3 * spatial_patch_size**2 * temporral patch_size) + + This module provides a collator class that handles both image and text data, + converting images to patches and preparing text for model input. + + Example: + >>> # Initialize collator + >>> collator = MultiModalCollatorNLD( + ... batch_size=2, + ... seq_len=32, + ... max_images_per_batch=4, + ... max_patch_per_image=6, + ... patch_size=16, + ... padding_idx=0, + ... ) + >>> + >>> # Create sample batch + >>> batch = [ + ... { + ... "input_ids": torch.tensor([1, 2, 3]), + ... "labels": torch.tensor([2, 3, 4]), + ... "pixel_values": [ + ... torch.randn(1, 32, 32, 3), + ... torch.randn(1, 32, 48, 3) + ... ] + ... }, + ... { + ... "input_ids": torch.tensor([5, 6]), + ... "labels": torch.tensor([6, 7]), + ... "pixel_values": [ + ... torch.randn(1, 32, 32, 3) # One image + ... ] + ... } + ... ] + >>> + >>> # Collate batch + >>> outputs = collator(batch) + >>> + >>> # Examine outputs + >>> print(outputs["input_ids"].shape) # (2, 32) - Padded to seq_len + >>> print(outputs["labels"].shape) # (2, 32) - Padded to seq_len + >>> print(outputs["pixel_values"].shape) # (4, 6, 768) - (N=4 images, L=6 patches, D=16*16*3) + >>> print(outputs["grid_thw"].shape) # (4, 6, 3) - Coordinates for each patch + >>> + >>> # The collated batch has: + >>> # 1. Text tensors padded to max length + >>> # 2. Images converted to patches in NLD format + >>> # 3. Grid coordinates for each patch + >>> # 4. All tensors properly batched and padded + """ + + batch_size: int # LLM's batch size + seq_len: int # LLM's maximum sequence length + + patch_size: int # Patch size for converting images to patches + max_images_per_batch: int # Vision Encoder's batch size + max_patches_per_image: int # Vision Encoder's sequence length + + padding_idx: int = 0 + ignore_idx: int = -100 + + def process_images( + self, all_images: list[torch.Tensor] + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Process a list of image tensors into patches with coordinate grids. + + Args: + all_images: list of image tensors, each of shape (T, H, W, 3) + + Returns: + patches: Tensor of shape (N, L, D) or None if no images + grids: Tensor of shape (N, L, 3) or None if no images + """ + if not all_images: + return None, None + + patch_list, grid_list = [], [] + for img in all_images: + # Convert single image to patches + patches, grids = convert_to_patches(img, patch_size=self.patch_size) + + # Pad/truncate to max patches + patches, grids = pad_patches(patches, grids, self.max_patches_per_image) + + patch_list.append(patches) + grid_list.append(grids) + + # Stack all images + patches = torch.stack(patch_list) + grids = torch.stack(grid_list) + + # Pad to max_images_per_batch with empty images + patches, grids = pad_empty_images_to_target_batch_size( + patches, grids, self.max_images_per_batch + ) + + return patches, grids + + def process_text( + self, + batch: list[dict[str, Any]], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Process text inputs and labels from batch. + + Args: + batch: list of dictionaries containing "input_ids" and "labels" + + Returns: + input_ids: Tensor of shape (B, L) + labels: Tensor of shape (B, L) + + Note: + B = batch size (padded if needed) + L = sequence length (padded/truncated to seq_len) + """ + # Pad sequences in batch + input_ids = pad_sequence( + [s["input_ids"] for s in batch], + batch_first=True, + padding_value=self.padding_idx, + ) + labels = pad_sequence( + [s["labels"] for s in batch], + batch_first=True, + padding_value=self.padding_idx, + ) + + # Handle sequence length + input_ids, labels = pad_text_batch( + input_ids, + labels, + self.seq_len + 1, # Extra token for label shifting + padding_idx=self.padding_idx, + ignore_idx=self.ignore_idx, + ) + input_ids, labels = pad_input_ids_and_labels_to_target_batch_size( + input_ids, + labels, + self.batch_size, + padding_idx=self.padding_idx, + ignore_idx=self.ignore_idx, + ) + + return input_ids[:, :-1], labels[:, 1:] # Shift for next token prediction + + def __call__( + self, batch: list[dict[str, Any]] + ) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]: + """Encode batch with patch-based approach. + + Args: + batch: list of dictionaries containing: + - input_ids: Tensor of shape (S) + - labels: Tensor of shape (L) + - pixel_values: list of tensors, each (1, H, W, 3) + + Returns: + Dictionary containing: + - input_ids: Tensor of shape (B, L) + - labels: Tensor of shape (B, L) + - pixel_values: Tensor of shape (N, L, D) + - grid_thw: Tensor of shape (N, L, 3) + """ + # Count images per sample and total images + images_per_sample = [] + for sample in batch: + num_images = len(sample.get("pixel_values", [])) + images_per_sample.append(num_images) + + # Remove samples from end until total images <= max_images_per_batch + total_images = sum(images_per_sample) + while total_images > self.max_images_per_batch and batch: + removed_images = images_per_sample.pop() + total_images -= removed_images + batch.pop() + logger.warning( + f"Removed sample with {removed_images} images to keep " + f"total images <= {self.max_images_per_batch}" + ) + + # Process all images in batch + all_images = [ + img + for sample in batch + if "pixel_values" in sample + for img in sample["pixel_values"] + ] + patches, grids = self.process_images(all_images) + + # Process text and pad to batch size + input_ids, labels = self.process_text(batch) + input_dict = {"input": input_ids, "pixel_values": patches, "grid_thw": grids} + + return input_dict, labels diff --git a/torchtitan/experiments/vlm/datasets/mm_datasets.py b/torchtitan/experiments/vlm/datasets/mm_datasets.py new file mode 100644 index 000000000..3c3810c65 --- /dev/null +++ b/torchtitan/experiments/vlm/datasets/mm_datasets.py @@ -0,0 +1,442 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Multimodal dataset implementation for vision-language models training. + +This module provides dataset classes for handling multimodal data +including images and text. Images are interleaved with text at native aspect ratio and resolution. +It supports both streaming and non-streaming datasets from HuggingFace. +""" + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from datasets import Dataset, load_dataset +from datasets.distributed import split_dataset_by_node +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset + +from torchtitan.components.dataloader import ParallelAwareDataloader +from torchtitan.components.tokenizer import BaseTokenizer, HuggingFaceTokenizer +from torchtitan.config import JobConfig +from torchtitan.tools.logging import logger + +from .image_utils import calculate_image_tokens, process_image +from .mm_collator_nld import MultiModalCollatorNLD +from .packing_utils import SamplePacker +from .text_utils import process_text_with_images + +IGNORE_INDEX = -100 # Pytorch F.cross_entropy default + + +def _process_mm_sample( + texts: list[str] | str, + images: list[bytes] | bytes, + tokenizer: BaseTokenizer, + patch_size: int, + max_patch_per_image: int, + spatial_merge_size: int, + special_tokens, +) -> dict[str, Any] | None: + """Common processing logic for multimodal samples. + + Args: + texts: List of strings with None indicating image positions + images: List of image bytes with None for text positions + tokenizer: Tokenizer for text processing + patch_size: Size of image patches + max_patch_per_image: Maximum patches per image + spatial_merge_size: merge 2D image patches to reduce LLM's sequence length. + - if 1 (default): no merge, effectively NoOp + - if 2: 2x2=4 image patches will be reduced to 1 LLM sequence + + Returns: + Dict with: + - input_ids: Tensor of token IDs + - labels: Tensor of label IDs + - pixel_values: List of processed image tensors + + Example: + Interleaved format: + texts = [text1, None, text2, None, text3] + images = [None, img1, None, img2, None] + + Image-text pair format as a special case of interleaved: + texts = [None, text] + images = [image, None] + """ + try: + # Normalize inputs to lists + texts = [texts] if isinstance(texts, str) else texts + images = [images] if isinstance(images, bytes) else images + + if not texts or len(texts) != len(images): + return None + + # Process all images first + processed_images = [] + image_dimensions = [] + texts_list = list(texts) # Make mutable copy + + for idx, img in enumerate(images): + if img is not None: + processed_img = process_image( + img, + patch_size=patch_size, + merge_size=spatial_merge_size, + max_patch_per_image=max_patch_per_image, + ) + if processed_img is not None: + num_tokens, width, height = calculate_image_tokens( + processed_img, + patch_size=patch_size, + spatial_merge_size=spatial_merge_size, + ) + processed_images.append(processed_img) + image_dimensions.append((num_tokens, width, height)) + # Replace None with image token + texts_list[idx] = special_tokens.img_token + else: + # Replace None with empty string if processing failed + texts_list[idx] = "" + + if not processed_images: + return None + + # Process all image tokens at once + processed_text = process_text_with_images( + texts_list, image_dimensions, tokenizer, special_tokens, add_eos=True + ) + + tokens = tokenizer.encode(processed_text) + + # Convert to tensors + input_ids = torch.tensor(tokens) + labels = torch.tensor(tokens) + + # Mask special tokens in labels + special_token_ids = torch.tensor( + [special_tokens.boi_id, special_tokens.eoi_id, special_tokens.img_id] + ) + labels = torch.where( + torch.isin(labels, special_token_ids), IGNORE_INDEX, labels + ) + + return { + "input_ids": input_ids, + "labels": labels, + "pixel_values": processed_images, + } + + except Exception as e: + logger.warning(f"Error processing sample: {e}") + return None + + +def _process_obelics_sample( + sample: dict[str, Any], + tokenizer: HuggingFaceTokenizer, + patch_size: int, + spatial_merge_size: int, + max_patch_per_image: int, + special_tokens, +) -> dict[str, Any] | None: + """Process a sample from the OBELICS dataset.""" + return _process_mm_sample( + texts=sample.get("texts", []), + images=sample.get("images", []), + tokenizer=tokenizer, + patch_size=patch_size, + spatial_merge_size=spatial_merge_size, + max_patch_per_image=max_patch_per_image, + special_tokens=special_tokens, + ) + + +def _process_cc12_wd_sample( + sample: dict[str, Any], + tokenizer: BaseTokenizer, + patch_size: int, + spatial_merge_size: int, + max_patch_per_image: int, + special_tokens, +) -> dict[str, Any] | None: + """Process a sample from the CC12-WD dataset. + Transforms CC12-WD format to match Interleaved format: + - texts: [None, text] to indicate image position + - images: [image, None] to match text position + """ + text = sample.get("txt", "") + image = sample.get("jpg", None) + + # Transform to OBELICS format + texts = [None, text] + images = [image, None] + + return _process_mm_sample( + texts=texts, + images=images, + tokenizer=tokenizer, + patch_size=patch_size, + spatial_merge_size=spatial_merge_size, + max_patch_per_image=max_patch_per_image, + special_tokens=special_tokens, + ) + + +@dataclass +class MMDatasetConfig: + path: str + loader: Callable + sample_processor: Callable + + +MM_DATASETS = { + "obelics": MMDatasetConfig( + path="HuggingFaceM4/OBELICS", + loader=lambda path: load_dataset(path, split="train", streaming=True), + sample_processor=_process_obelics_sample, + ), + "cc12m": MMDatasetConfig( + path="pixparse/cc12m-wds", + loader=lambda path: load_dataset(path, split="train", streaming=True), + sample_processor=_process_cc12_wd_sample, + ), +} + + +def _validate_mm_dataset( + dataset_name: str, dataset_path: str | None = None +) -> tuple[str, Callable, Callable]: + """Validate dataset name and path.""" + if dataset_name not in MM_DATASETS: + raise ValueError( + f"Dataset {dataset_name} is not supported. " + f"Supported datasets are: {list(MM_DATASETS.keys())}" + ) + + config = MM_DATASETS[dataset_name] + path = dataset_path or config.path + logger.info(f"Preparing {dataset_name} dataset from {path}") + return path, config.loader, config.sample_processor + + +class MultiModalDataset(IterableDataset, Stateful): + """PyTorch MultiModal Dataset with support for sample packing.""" + + def __init__( + self, + dataset_name: str, + dataset_path: str | None, + tokenizer: BaseTokenizer, + batch_size: int, + seq_len: int, + patch_size: int, + spatial_merge_size: int, + max_patches_per_image: int, + max_images_per_batch: int, + packing_buffer_size: int, + special_tokens, + dp_rank: int = 0, + dp_world_size: int = 1, + infinite: bool = False, + ) -> None: + # Force lowercase for consistent comparison + dataset_name = dataset_name.lower() + + path, dataset_loader, self.sample_processor = _validate_mm_dataset( + dataset_name, dataset_path + ) + ds = dataset_loader(path) + self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) + + self._tokenizer = tokenizer + self.batch_size = batch_size + self.seq_len = seq_len + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.max_patches_per_image = max_patches_per_image + self.max_images_per_batch = max_images_per_batch + self.special_tokens = special_tokens + self.enable_packing = packing_buffer_size > 0 + if self.enable_packing: + self.packer = SamplePacker( + max_seq_length=seq_len, + buffer_size=packing_buffer_size, + batch_size=batch_size, + ) + self.infinite = infinite + self._sample_idx = 0 + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + try: + self._sample_idx += 1 + + processed = self.sample_processor( + sample=sample, + tokenizer=self._tokenizer, + patch_size=self.patch_size, + spatial_merge_size=self.spatial_merge_size, + max_patch_per_image=self.max_patches_per_image, + special_tokens=self.special_tokens, + ) + if processed is None: + continue + + if processed["input_ids"].shape[0] > self.seq_len: + logger.warning( + f"Sample length {processed['input_ids'].shape[0]} > training {self.seq_len=}. Skip" + ) + continue + + if processed is not None: + + if self.enable_packing: + self.packer.add_sample(processed) + + if self.packer.has_batch_ready(): + batch = self.packer.get_next_batch() + if batch: + yield from batch + else: + yield processed # individual sample + + except Exception as e: + logger.warning(f"Error in iteration: {e}") + continue + + if self.enable_packing: + # Handle remaining samples in packer + while True: + batch = self.packer.get_next_batch() + if batch: + yield from batch + else: + break + + if not self.infinite: + break + else: + self._sample_idx = 0 + + def _get_data_iter(self): + try: + if not hasattr(self._data, "iterable_dataset"): + if isinstance(self._data, Dataset) and ( + self._sample_idx == len(self._data) + ): + return iter([]) + + it = iter(self._data) + + if self._sample_idx > 0: + for _ in range(self._sample_idx): + next(it) + + return it + except Exception as e: + logger.error(f"Error in _get_data_iter: {e}") + return iter([]) + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + + # Restore packer state if available + if ( + self.enable_packing + and hasattr(self, "packer") + and "packer_state" in state_dict + ): + packer_state = state_dict["packer_state"] + self.packer.sample_buffer.clear() + self.packer.packed_samples.clear() + self.packer.sample_buffer.extend(packer_state["sample_buffer"]) + self.packer.packed_samples.extend(packer_state["packed_samples"]) + + def state_dict(self): + state = {"sample_idx": self._sample_idx} + + # Save packer state if packing is enabled + if self.enable_packing and hasattr(self, "packer"): + state["packer_state"] = { + "sample_buffer": list(self.packer.sample_buffer), + "packed_samples": list(self.packer.packed_samples), + } + + return state + + +def build_mm_dataloader( + dp_world_size: int, + dp_rank: int, + tokenizer: HuggingFaceTokenizer, + job_config: JobConfig, + infinite: bool = True, +) -> ParallelAwareDataloader: + """Build a data loader for multimodal datasets. + + Args: + dp_world_size: Data parallel world size + dp_rank: Data parallel rank + tokenizer: Tokenizer for text processing + job_config: Job configuration + infinite: Whether to loop infinitely + + Returns: + DataLoader with appropriate parallelism handling + """ + dataset_path = job_config.training.dataset_path + batch_size = job_config.training.local_batch_size + seq_len = job_config.training.seq_len + + max_images_per_batch = job_config.training.max_images_per_batch + max_patches_per_image = job_config.training.max_patches_per_image + packing_buffer_size = job_config.training.packing_buffer_size + spatial_merge_size = job_config.training.spatial_merge_size + + # NOTE: technically patch_size belongs to model variants, but we don't + # have access to model_args here. To discuss later. + patch_size = job_config.training.patch_size + + dataset = MultiModalDataset( + dataset_name=job_config.training.dataset, + dataset_path=dataset_path, + tokenizer=tokenizer, + batch_size=batch_size, + seq_len=seq_len, + patch_size=patch_size, + spatial_merge_size=spatial_merge_size, + max_patches_per_image=max_patches_per_image, + max_images_per_batch=max_images_per_batch, + packing_buffer_size=packing_buffer_size, + special_tokens=job_config.special_tokens, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + infinite=infinite, + ) + + collate_fn = MultiModalCollatorNLD( + batch_size=batch_size, + seq_len=job_config.training.seq_len, + patch_size=patch_size, + max_images_per_batch=max_images_per_batch, + max_patches_per_image=max_patches_per_image, + padding_idx=job_config.special_tokens.pad_id, + ignore_idx=IGNORE_INDEX, + ) + + base_dataloader = ParallelAwareDataloader( + dataset=dataset, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + batch_size=batch_size, + collate_fn=collate_fn, + ) + + return base_dataloader diff --git a/torchtitan/experiments/vlm/datasets/packing_utils.py b/torchtitan/experiments/vlm/datasets/packing_utils.py new file mode 100644 index 000000000..3abda4178 --- /dev/null +++ b/torchtitan/experiments/vlm/datasets/packing_utils.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Utilities for efficient sample packing in multimodal datasets.""" + +from collections import deque +from typing import Any + +import torch +from torchtitan.tools.logging import logger + + +class SamplePacker: + """Packs multiple samples together to maximize sequence length utilization.""" + + def __init__( + self, + max_seq_length: int, + buffer_size: int = 100, + batch_size: int = 8, + ): + self.max_seq_length = max_seq_length + self.buffer_size = buffer_size + self.batch_size = batch_size + + # Initialize buffers + self.sample_buffer: deque = deque(maxlen=buffer_size) + self.packed_samples: deque = deque() + + def _pack_buffered_samples(self) -> list[dict[str, Any]]: + """Pack buffered samples into optimal sequences.""" + if not self.sample_buffer: + return [] + + # Sort samples by length for better packing + samples = sorted( + self.sample_buffer, key=lambda x: len(x["input_ids"]), reverse=True + ) + + packed_sequences = [] + current_sequence = [] + current_length = 0 + + for sample in samples: + sample_length = len(sample["input_ids"]) + + # Skip very long samples + if sample_length > self.max_seq_length: + logger.warning( + f"Sample length {sample_length} exceeds max_seq_length " + f"{self.max_seq_length}, will be skipped" + ) + continue + + # Check if adding this sample would exceed max length + if current_sequence and ( + current_length + sample_length > self.max_seq_length + ): + # Current sequence is full, create packed sample + packed_sequences.append( + { + "input_ids": torch.cat( + [s["input_ids"] for s in current_sequence] + ), + "labels": torch.cat([s["labels"] for s in current_sequence]), + "pixel_values": [ + img for s in current_sequence for img in s["pixel_values"] + ], + } + ) + current_sequence = [] + current_length = 0 + + # Add sample to current sequence + current_sequence.append(sample) + current_length += sample_length + + # Handle remaining sequence + if current_sequence: + packed_sequences.append( + { + "input_ids": torch.cat([s["input_ids"] for s in current_sequence]), + "labels": torch.cat([s["labels"] for s in current_sequence]), + "pixel_values": [ + img for s in current_sequence for img in s["pixel_values"] + ], + } + ) + + # Clear buffer + self.sample_buffer.clear() + return packed_sequences + + def add_sample(self, sample: dict[str, Any]) -> None: + """Add a sample to the buffer.""" + self.sample_buffer.append(sample) + + if len(self.sample_buffer) >= self.buffer_size: + packed = self._pack_buffered_samples() + self.packed_samples.extend(packed) + + def has_batch_ready(self) -> bool: + """Check if a full batch is ready.""" + return len(self.packed_samples) >= self.batch_size + + def get_next_batch(self) -> list[dict[str, Any]] | None: + """Get next batch of packed samples if available.""" + if not self.has_batch_ready(): + # Try to pack any remaining samples + if self.sample_buffer: + packed = self._pack_buffered_samples() + self.packed_samples.extend(packed) + + if not self.has_batch_ready(): + return None + + batch = [] + for _ in range(self.batch_size): + if not self.packed_samples: + break + batch.append(self.packed_samples.popleft()) + + return batch diff --git a/torchtitan/experiments/vlm/datasets/text_utils.py b/torchtitan/experiments/vlm/datasets/text_utils.py new file mode 100644 index 000000000..5db57b215 --- /dev/null +++ b/torchtitan/experiments/vlm/datasets/text_utils.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def pad_text_batch( + input_ids: torch.Tensor, + labels: torch.Tensor, + seq_len: int, + padding_idx: int = 0, + ignore_idx: int = -100, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pad input_ids and labels to desired sequence length. + + Args: + input_ids: Tensor of shape (B, L) + labels: Tensor of shape (B, L) + seq_len: Desired sequence length + padding_idx: Token ID to use for padding + ignore_idx: Token ID to use for ignored positions in labels + + Returns: + padded_input_ids: Tensor of shape (B, seq_len) + padded_labels: Tensor of shape (B, seq_len) + """ + B, L = input_ids.shape + + if L < seq_len: + # Pad to desired length + padding_length = seq_len - L + padding_input = torch.full( + (B, padding_length), padding_idx, dtype=torch.long, device=input_ids.device + ) + padding_labels = torch.ones( + (B, padding_length), dtype=torch.long, device=labels.device + ) + + input_ids = torch.cat([input_ids, padding_input], dim=1) + labels = torch.cat([labels, padding_labels], dim=1) + + elif L > seq_len: + # Truncate to desired length + input_ids = input_ids[:, :seq_len] + labels = labels[:, :seq_len] + + # Convert padding tokens to ignore_idx in labels + labels[labels == padding_idx] = ignore_idx + + return input_ids, labels + + +def pad_input_ids_and_labels_to_target_batch_size( + input_ids: torch.Tensor, + labels: torch.Tensor, + target_batch_size: int, + padding_idx: int = 0, + ignore_idx: int = -100, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pad batch dimension to target size. + + Args: + input_ids: Tensor of shape (B, L) + labels: Tensor of shape (B, L) + target_batch_size: Desired batch size + padding_idx: Token ID to use for padding + ignore_idx: Token ID to use for ignored positions in labels + + Returns: + padded_input_ids: Tensor of shape (target_batch_size, L) + padded_labels: Tensor of shape (target_batch_size, L) + """ + B, L = input_ids.shape + if B >= target_batch_size: + return input_ids, labels + + padding_needed = target_batch_size - B + padding_input = torch.full( + (padding_needed, L), padding_idx, dtype=torch.long, device=input_ids.device + ) + padding_labels = torch.full( + (padding_needed, L), padding_idx, dtype=torch.long, device=labels.device + ) + + input_ids = torch.cat([input_ids, padding_input], dim=0) + labels = torch.cat([labels, padding_labels], dim=0) + + # Convert padding tokens to ignore_idx in labels + labels[labels == padding_idx] = ignore_idx + + return input_ids, labels + + +def process_text_with_images( + text: list[str], + image_tokens: list[tuple[int, int, int]], # [(total, width, height), ...] + tokenizer, + special_tokens, + add_eos: bool = True, +) -> str: + """Process text by interleaving image tokens efficiently. + + Args: + text: Raw text string + image_tokens: List of (total_tokens, width, height) for each image + tokenizer: Tokenizer with special tokens + add_eos: Whether to add EOS token + + Returns: + Processed text with image tokens inserted + + Example: + >>> text = ["", "photo of a cat"] + >>> image_tokens = [(16, 4, 4)] # 4x4 grid = 16 tokens + >>> result = process_text_with_images(text, image_tokens, tokenizer) + >>> print(result) # <|startofimage|><|image|>...<|endofimage|> A photo... + """ + parts = [] # Build parts list instead of string concat + image_idx = 0 + + for part in text: + if part == special_tokens.img_token and image_idx < len(image_tokens): + num_image_tokens, _, _ = image_tokens[image_idx] + + parts.extend( + [ + special_tokens.boi_token, + *([special_tokens.img_token] * num_image_tokens), + special_tokens.eoi_token, + ] + ) + image_idx += 1 + else: + parts.append(part) + + # Join all parts with spaces and add EOS if needed + result = "".join(parts) + return result.strip() + (tokenizer.eos_token if add_eos else "") diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py new file mode 100644 index 000000000..8477cb1ec --- /dev/null +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. + +from collections import defaultdict +from typing import Optional + +import torch +import torch.nn as nn +from torch.distributed._composable.replicate import replicate +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper, +) + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config.job_config import ActivationCheckpoint as ACConfig +from torchtitan.distributed import ParallelDims +from torchtitan.tools.logging import logger + + +def parallelize_vlm( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + assert isinstance(model.encoder, nn.Module) + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + if ( + job_config.parallelism.context_parallel_degree > 1 + and model.model_args.use_flex_attn + ): + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + raise NotImplementedError("TP support for VLM training is still in progress.") + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + apply_ac(model.encoder, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if job_config.compile.enable: + apply_compile(model) + apply_compile(model.encoder) + + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=job_config.compile.enable, + enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, + ) + + return model + + +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, +} + + +def _apply_ac_to_transformer_block( + module: nn.Module, ac_config: ACConfig, *, base_fqn: Optional[str] = None +): + valid_ac_modes = ("full", "selective") + if ac_config.mode not in valid_ac_modes: + raise ValueError( + f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" + ) + + if ac_config.mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.mode == "selective", f"{ac_config.mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. " + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, + ) + + mm_recompute_shapes = set() + if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: + for module_fqn, submod in module.named_modules(): + fqn = module_fqn + if base_fqn is not None: + fqn = f"{base_fqn}.{module_fqn}" + if not any( + filter_fqn in fqn + for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns + ): + continue + if not isinstance(submod, nn.Linear): + raise ValueError( + "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " + f"a nn.Linear, but got: {submod}" + ) + out_f, in_f = submod.weight.shape + mm_recompute_shapes.add((in_f, out_f)) + logger.debug( + f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" + ) + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + if args[1].shape in mm_recompute_shapes: + return CheckpointPolicy.PREFER_RECOMPUTE + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac(model: nn.Module, ac_config): + """Apply activation checkpointing to the model.""" + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = _apply_ac_to_transformer_block( + transformer_block, ac_config, base_fqn=f"layers.{layer_id}" + ) + model.layers.register_module(layer_id, transformer_block) + + logger.info( + f"Applied {ac_config.mode} activation checkpointing to the model {type(model).__name__}" + ) + + +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = torch.compile(transformer_block, fullgraph=True) + model.layers.register_module(layer_id, transformer_block) + + logger.info( + f"Compiling each TransformerBlock of {type(model).__name__} with torch.compile" + ) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + for layer_id, transformer_block in model.encoder.layers.items(): + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + for layer_id, transformer_block in model.layers.items(): + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + fully_shard(model, **fsdp_config) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, +): + if enable_compile: + if enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") diff --git a/torchtitan/experiments/vlm/model/args.py b/torchtitan/experiments/vlm/model/args.py new file mode 100644 index 000000000..4bf6583cc --- /dev/null +++ b/torchtitan/experiments/vlm/model/args.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from torchtitan.models.llama3 import TransformerModelArgs as Llama3Args + + +@dataclass +class Siglip2ModelArgs: + dim: int = 768 + ffn_dim: int = 3072 + n_layers: int = 12 + n_heads: int = 12 + + n_pos_embs: int = 16 # Number of positional embeddings per h&w + n_channels: int = 3 # RGB channels + patch_size: int = 16 + + layer_norm_eps: float = 1e-6 + use_flex_attn: bool = True + attn_mask_type: str = "causal" + + +@dataclass +class Llama3Siglip2ModelArgs(Llama3Args): + encoder: Siglip2ModelArgs = field(default_factory=Siglip2ModelArgs) + img_token_id: int = 1998 diff --git a/torchtitan/experiments/vlm/model/model.py b/torchtitan/experiments/vlm/model/model.py new file mode 100644 index 000000000..e33d11f9b --- /dev/null +++ b/torchtitan/experiments/vlm/model/model.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import einops as E +import torch +from torch import nn + +from torchtitan.models.attention import init_attention_mask +from torchtitan.models.llama3 import Transformer as Llama3 + +from .args import Llama3Siglip2ModelArgs +from .siglip2 import VisionTransformer + + +class Projector(nn.Module): + """Project the Encoder embedding to the LLM embedding.""" + + def __init__(self, in_dim: int, out_dim: int) -> None: + super().__init__() + self.w1 = nn.Linear(in_dim, in_dim) + self.w2 = nn.Linear(in_dim, out_dim) + self.init_weights() + + def forward(self, x_NLD: torch.Tensor): + x_NLD = self.w1(x_NLD) + x_NLD = nn.functional.silu(x_NLD) + x_NLD = self.w2(x_NLD) + return x_NLD + + def init_weights(self): + nn.init.xavier_uniform_(self.w1.weight) + if self.w1.bias is not None: + nn.init.zeros_(self.w1.bias) + nn.init.xavier_uniform_(self.w2.weight) + if self.w2.bias is not None: + nn.init.zeros_(self.w2.bias) + + +class Llama3Siglip2Transformer(Llama3): + def __init__(self, model_args: Llama3Siglip2ModelArgs): + super().__init__(model_args) + self.model_args = model_args + self.encoder = VisionTransformer(model_args.encoder) + self.projector = Projector( + in_dim=model_args.encoder.dim, out_dim=model_args.dim + ) + self.n_pixels_per_token = model_args.encoder.patch_size**2 + self.init_encoder_weights() + + def init_encoder_weights(self, buffer_device=None): + super().init_weights(buffer_device=buffer_device) + if self.encoder is not None: + self.encoder.init_weights() + if self.projector is not None: + self.projector.init_weights() + + def _scatter_img_tokens(self, h_BSD, tokens_BS, i_NLD, i_mask_NL, img_id=None): + img_id = img_id or self.model_args.img_token_id + B, S, D = h_BSD.shape + # Where are the image tokens in LLM input, make broadcastable with h_BSD + img_mask_h_BSD = E.repeat(tokens_BS == img_id, "b s -> b s 1") + # Only get valid (non-padded) tokens, result are flatten + i_flatten = torch.masked_select(i_NLD, mask=i_mask_NL.unsqueeze(-1)) + + assert i_flatten.numel() // D == img_mask_h_BSD.sum(), ( + f"Different number of visual embeddings {i_flatten.numel() // D} " + f"with placeholder in input token embeddings {img_mask_h_BSD.sum()}" + ) + h_BSD.masked_scatter_(mask=img_mask_h_BSD, source=i_flatten) + return h_BSD + + def forward( + self, + tokens: torch.Tensor, + eos_id: int | None = None, + input_batch: torch.Tensor | None = None, + pixel_values: torch.Tensor | None = None, + grid_thw: torch.Tensor | None = None, + ): + if self.model_args.use_flex_attn: + init_attention_mask( + input_batch if input_batch is not None else tokens, eos_id=self.eos_id + ) + + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + if self.encoder is not None: + grid_hw = grid_thw[:, :, 1:] # Siglip2 only support image hw + pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all") + i_NLD = self.encoder(pixel_values, pixel_masks, grid_hw) + i_NLD = self.projector(i_NLD) + h_BSD = self._scatter_img_tokens(h_BSD, tokens, i_NLD, pixel_masks) + + for layer in self.layers.values(): + h_BSD = layer(h_BSD, self.freqs_cis) + + h_BSD = self.norm(h_BSD) if self.norm else h_BSD + output = self.output(h_BSD) if self.output else h_BSD + return output diff --git a/torchtitan/experiments/vlm/model/siglip2.py b/torchtitan/experiments/vlm/model/siglip2.py new file mode 100644 index 000000000..a1183f7cb --- /dev/null +++ b/torchtitan/experiments/vlm/model/siglip2.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import einops as E +import torch +import torch.nn.functional as F +from torch import nn + +from torchtitan.models.attention import build_attention, init_attention_mask + +from .args import Siglip2ModelArgs + + +def resize_positional_embeddings( + pos_embs_HWD: torch.Tensor, + spatial_shapes_N2: torch.Tensor, + max_length: int, +) -> torch.Tensor: + """ + Resize the learned 2D positional embeddings to image-specific size and pad to a fixed size. + + Args: + pos_embs_HWD (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + _, _, D = pos_embs_HWD.shape + B, _ = spatial_shapes_N2.shape + + resized_embs_BLD = torch.empty( + (B, max_length, D), + device=pos_embs_HWD.device, + dtype=pos_embs_HWD.dtype, + ) + + # TODO: group images by size, and do interpolate, + # or cache the interpolate output so we do this once per size + for i in range(B): + height, width = spatial_shapes_N2[i].tolist() + if (height + width) == 0: # Skip empty padding images + continue + + resized_emb = F.interpolate( + E.rearrange(pos_embs_HWD, "h w d -> 1 d h w"), + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + resized_emb_LD = E.rearrange(resized_emb, "1 d h w -> (h w) d") + resized_embs_BLD[i, : int(height * width)] = resized_emb_LD + + return resized_embs_BLD + + +class VisionEmbeddings(nn.Module): + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.patch_embedding = nn.Linear( + in_features=args.n_channels * args.patch_size * args.patch_size, + out_features=args.dim, + ) + self.position_embedding = nn.Embedding(args.n_pos_embs**2, args.dim) + self.n_pos_embs = args.n_pos_embs + + def init_weights(self): + nn.init.trunc_normal_(self.patch_embedding.weight, mean=0.0, std=0.02) + nn.init.normal_(self.position_embedding.weight) + + def forward(self, pixels_NLD: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + # Apply patch embeddings to already patchified pixel values + patch_embeds_NLD = self.patch_embedding(pixels_NLD) + + # Get positional resized and padded positional embeddings + pos_emb_HWD = self.position_embedding.weight.reshape( + self.n_pos_embs, self.n_pos_embs, -1 + ) + spatial_h = E.reduce(grid_hw[:, :, 0], "n l -> n", reduction="max") + 1 + spatial_w = E.reduce(grid_hw[:, :, 1], "n l -> n", reduction="max") + 1 + spatial_shapes = torch.stack([spatial_h, spatial_w], dim=-1).long() + resized_positional_embeddings = resize_positional_embeddings( + pos_emb_HWD, + spatial_shapes, + max_length=pixels_NLD.shape[1], + ) + # Add positional embeddings to patch embeddings + embeddings = patch_embeds_NLD + resized_positional_embeddings + return embeddings + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of query heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + + self.q_proj = nn.Linear(self.dim, self.dim) + self.k_proj = nn.Linear(self.dim, self.dim) + self.v_proj = nn.Linear(self.dim, self.dim) + self.out_proj = nn.Linear(self.dim, self.dim) + + self.attn = build_attention( + use_flex_attn=True, attn_mask_type=args.attn_mask_type + ) + + def forward(self, x: torch.Tensor): + xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Use self.head_dim instead of `n_heads` to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = E.rearrange(xq, "b l (h d) -> b h l d", d=self.head_dim) + xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim) + xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim) + + output = self.attn(xq, xk, xv) + output = E.rearrange(output, "b h l d -> b l (h d)").contiguous() + + return self.out_proj(output) + + def init_weights(self): + for linear in (self.q_proj, self.k_proj, self.v_proj, self.out_proj): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + + +class FeedForward(nn.Module): + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.fc1 = nn.Linear(args.dim, args.ffn_dim) + self.fc2 = nn.Linear(args.ffn_dim, args.dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = F.gelu(x, approximate="tanh") + x = self.fc2(x) + return x + + def init_weights(self): + nn.init.trunc_normal_(self.fc1.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.fc2.weight, mean=0.0, std=0.02) + + +class TransformerLayer(nn.Module): + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.layer_norm1 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + self.self_attn = Attention(args) + self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + self.mlp = FeedForward(args) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.self_attn(self.layer_norm1(x)) + x = x + self.mlp(self.layer_norm2(x)) + return x + + def init_weights(self): + self.layer_norm1.reset_parameters() + self.layer_norm2.reset_parameters() + self.self_attn.init_weights() + self.mlp.init_weights() + + +class VisionTransformer(nn.Module): + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.args = args + self.eos_id = 11 + + self.embeddings = VisionEmbeddings(args) + self.layers = nn.ModuleDict( + {str(idx): TransformerLayer(args) for idx in range(args.n_layers)} + ) + self.post_layernorm = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + + def forward( + self, + pixel_values_NLD: torch.FloatTensor, + pixel_masks_NL: torch.BoolTensor, + grid_hw: torch.LongTensor, + ): + init_attention_mask(pixel_masks_NL, eos_id=self.eos_id) + + h = self.embeddings(pixel_values_NLD, grid_hw) + + for layer in self.layers.values(): + h = layer(h) + h = self.post_layernorm(h) + + return h + + def init_weights(self): + self.embeddings.init_weights() + for layer in self.layers.values(): + layer.init_weights() + self.post_layernorm.reset_parameters() diff --git a/torchtitan/experiments/vlm/requirements.txt b/torchtitan/experiments/vlm/requirements.txt new file mode 100644 index 000000000..d27fa26c6 --- /dev/null +++ b/torchtitan/experiments/vlm/requirements.txt @@ -0,0 +1 @@ +einops diff --git a/torchtitan/experiments/vlm/train_configs/debug_model.toml b/torchtitan/experiments/vlm/train_configs/debug_model.toml new file mode 100644 index 000000000..1a0e2c5c1 --- /dev/null +++ b/torchtitan/experiments/vlm/train_configs/debug_model.toml @@ -0,0 +1,86 @@ +# torchtitan Config.toml + +[experimental] +custom_args_module = "torchtitan.experiments.vlm.assets.job_config" + +[job] +dump_folder = "./outputs" +description = "Llama 3 Siglip2 VLM debug training" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama3-siglip2" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "torchtitan/experiments/vlm/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 32 +seq_len = 4096 +max_patches_per_image = 1024 +max_images_per_batch = 64 +# packing_buffer_size = 100 +max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +# dataset = "obelics" +dataset = "cc12m" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enabled = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/train.py b/torchtitan/train.py index e38446a39..be8e274ab 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -410,6 +410,7 @@ def forward_backward_step( # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage inputs = input_dict["input"] + extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], @@ -430,7 +431,7 @@ def forward_backward_step( ) if self.pp_has_first_stage: self.pp_schedule.step( - inputs, target=targets, losses=losses, input_batch=inputs + inputs, **extra_inputs, target=targets, losses=losses, input_batch=inputs ) else: self.pp_schedule.step( @@ -449,7 +450,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id) + pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id, **extra_inputs) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred