|
14 | 14 | import torch |
15 | 15 | import torch._dynamo.config |
16 | 16 | import torch._inductor.config |
17 | | -import torch.nn as nn |
| 17 | +import torch.distributed as dist |
18 | 18 |
|
19 | | -from torchchat.model import Model, ModelArgs, ModelType |
| 19 | +from torchchat.distributed.utils import ( |
| 20 | + Color as color, |
| 21 | + CUDATrackTime, |
| 22 | + init_distributed, |
| 23 | + GPUMemoryMonitor, |
| 24 | +) |
| 25 | +from torchchat.distributed.logging_utils import SingletonLogger |
20 | 26 |
|
| 27 | +from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs |
21 | 28 | from torchchat.model_config.model_config import resolve_model_config |
22 | 29 | from torchchat.utils.build_utils import ( |
23 | 30 | device_sync, |
|
28 | 35 | from torchchat.utils.measure_time import measure_time |
29 | 36 | from torchchat.utils.quantize import quantize_model |
30 | 37 |
|
| 38 | + |
31 | 39 | from torchtune.models.convert_weights import meta_to_tune |
32 | 40 |
|
33 | 41 | from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE |
@@ -598,6 +606,117 @@ def do_nothing(max_batch_size, max_seq_length): |
598 | 606 | model = PTEModel(config, builder_args.pte_path) |
599 | 607 | except Exception: |
600 | 608 | raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}") |
| 609 | + elif builder_args.distributed: |
| 610 | + # Using params_table to identify the model to load, for example "Meta-Llama-3.1-8B". |
| 611 | + # TODO: This is a hacky way to satisfy the distributed loading API and needs to be replaced |
| 612 | + NAME_TO_DISTRIBUTION = { |
| 613 | + "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct", |
| 614 | + "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct", |
| 615 | + "Meta-Llama-3-70B": "meta-llama/Meta-Llama-3-70B-Instruct", |
| 616 | + "Meta-Llama-3.1-70B": "meta-llama/Meta-Llama-3.1-70B-Instruct", |
| 617 | + |
| 618 | + } |
| 619 | + # TODO: Use information in builder_args directly to build model and load weights |
| 620 | + assert builder_args.params_table |
| 621 | + try: |
| 622 | + distribution = NAME_TO_DISTRIBUTION[builder_args.params_table] |
| 623 | + except KeyError as e: |
| 624 | + print(f"Unknown params_table: {builder_args.params_table}. Supported model names are: {', '.join(NAME_TO_DISTRIBUTION.keys())}") |
| 625 | + raise e |
| 626 | + |
| 627 | + pp_degree = builder_args.pp |
| 628 | + tp_degree = builder_args.tp |
| 629 | + |
| 630 | + init_distributed() |
| 631 | + rank = dist.get_rank() |
| 632 | + torch.cuda.set_device(rank % torch.cuda.device_count()) |
| 633 | + |
| 634 | + logger = SingletonLogger.get_logger() |
| 635 | + |
| 636 | + gpu_memory_monitor = GPUMemoryMonitor("cuda") |
| 637 | + logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") |
| 638 | + |
| 639 | + # Model-level config |
| 640 | + if builder_args.params_table: |
| 641 | + model_config = ModelArgs.from_table(builder_args.params_table) |
| 642 | + else: |
| 643 | + raise NotImplementedError() |
| 644 | + # Transformer-level config |
| 645 | + config = TransformerArgs.from_params(model_config.transformer_args["text"]) |
| 646 | + logger.info(f"Transformer Config: {config}") |
| 647 | + |
| 648 | + # TODO: Move to the top of the file once the circular import is resolved |
| 649 | + from torchchat.distributed.checkpoint_utils import ( |
| 650 | + load_model_weights, |
| 651 | + ) |
| 652 | + |
| 653 | + # Validate pipeline degree |
| 654 | + assert config.n_layers % pp_degree == 0, f"n_layers ({config.n_layers}) must be divisible by pp_degree ({pp_degree})" |
| 655 | + |
| 656 | + # Create device mesh |
| 657 | + device_mesh = dist.init_device_mesh( |
| 658 | + "cuda", |
| 659 | + (pp_degree, tp_degree), |
| 660 | + mesh_dim_names=("pp", "tp") |
| 661 | + ) |
| 662 | + tp_mesh = device_mesh["tp"] |
| 663 | + pp_mesh = device_mesh["pp"] |
| 664 | + logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") |
| 665 | + |
| 666 | + pp_rank = pp_mesh.get_local_rank() |
| 667 | + logger.info(f"{pp_degree=}, {tp_degree=}") |
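As a rough orientation for the mesh layout above, here is a minimal standalone sketch (the degrees and launch setup are illustrative, not taken from builder.py): with world_size = pp_degree * tp_degree, rank r lands at coordinate (r // tp_degree, r % tp_degree) in the ("pp", "tp") mesh.

    # Illustrative only: run under torchrun with 8 processes so that pp * tp == world size.
    import torch
    import torch.distributed as dist

    pp_degree, tp_degree = 2, 4              # hypothetical degrees for this sketch
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    mesh = dist.init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
    # Default rank layout is row-major: rank 0 -> (pp=0, tp=0), rank 5 -> (pp=1, tp=1), etc.
    print(rank, mesh["pp"].get_local_rank(), mesh["tp"].get_local_rank())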
| 668 | + |
| 669 | + # Assuming same number of GPUs per node |
| 670 | + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") |
| 671 | + |
| 672 | + # Fill in PP configs |
| 673 | + config.stage_idx = pp_rank |
| 674 | + config.n_stages = pp_degree |
| 675 | + |
| 676 | + with torch.device("meta"): |
| 677 | + # TODO: we should create model instead of Transformer |
| 678 | + model = Transformer(config) |
| 679 | + |
| 680 | + # Distribute model on TP mesh |
| 681 | + # (Surprisingly, this works even though the model is on the meta device and the |
| 682 | + # mesh is made of CUDA devices) |
| 683 | + model.distribute(tp_mesh) |
| 684 | + if rank == 0: |
| 685 | + logger.info(f"Model: {model}") |
| 686 | + |
| 687 | + # Load weights |
| 688 | + logger.info(f"Loading weights for {pp_rank=} on {device=}") |
| 689 | + with CUDATrackTime() as timer: |
| 690 | + load_model_weights(model, distribution, device, config, builder_args.chpt_from) |
| 691 | + |
| 692 | + logger.info( |
| 693 | + f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" |
| 694 | + ) |
| 695 | + |
| 696 | + # Setup KV caches (after model distribution) |
| 697 | + # The number of cache lanes is the same as the maximum number of |
| 698 | + # micro-batches that can be "in flight" in parallel -- imagine each |
| 699 | + # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces. |
| 700 | + # When decoding is done for certain micro-batches, we can reuse the KV cache |
| 701 | + # lanes. |
| 702 | + # TODO: bump up the lane count |
| 703 | + pipeline_lanes = 1 |
| 704 | + seqlen_prefill = 1024  # max sequence length used to size the KV caches |
| 705 | + with device: |
| 706 | + model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes) |
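To make the lane comment above concrete, a toy sketch of the intended bookkeeping once the lane count is bumped up (the numbers are hypothetical): lanes act like slots, and a finished micro-batch frees its slot for the next one.

    pipeline_lanes = 4                      # hypothetical lane count
    num_microbatches = 8                    # hypothetical number of micro-batches
    for mb in range(num_microbatches):
        lane = mb % pipeline_lanes          # at most 4 micro-batches in flight, each on its own lane
        print(f"micro-batch {mb} decodes into KV-cache lane {lane}")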
| 707 | + |
| 708 | + # info on stage size and params |
| 709 | + # stage_size = get_module_size(model) |
| 710 | + # stage_size_formatted = bytes_to_readable(stage_size) |
| 711 | + # stage_num_params = get_num_params(model) |
| 712 | + # logger.info( |
| 713 | + # f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}" |
| 714 | + # ) |
| 715 | + model.eval() |
| 716 | + |
| 717 | + model.text_transformer_args = None |
| 718 | + model.config.model_type = model_config.model_type |
| 719 | + model.device_mesh = device_mesh |
601 | 720 | else: |
602 | 721 | with measure_time("Time to load model: {time:.02f} seconds"): |
603 | 722 | model = _load_model(builder_args) |
|