From 481e00bea0a1385f4f50fc6bf47249a301d61a0b Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 14 Oct 2024 19:22:42 -0700 Subject: [PATCH 01/24] add pp_dim, distributed, num_gpus, num_nodes as cmd line args --- torchchat/cli/builder.py | 43 +++++++++++++++++++++++++--------------- torchchat/generate.py | 38 +++++++++++++++++------------------ 2 files changed, 45 insertions(+), 36 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 02b1545d0..dcff46bbe 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,20 +16,12 @@ import torch._inductor.config import torch.nn as nn -from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune - -from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama - from torch.distributed.device_mesh import DeviceMesh -from torchtune.models.convert_weights import meta_to_tune - -from torchtune.training import set_default_dtype +from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama from torchchat.model import Model, ModelArgs, ModelType -from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE - from torchchat.model_config.model_config import resolve_model_config from torchchat.utils.build_utils import ( device_sync, @@ -40,6 +32,14 @@ from torchchat.utils.measure_time import measure_time from torchchat.utils.quantize import quantize_model +from torchtune.models.convert_weights import meta_to_tune + +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE + +from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune + +from torchtune.training import set_default_dtype + @dataclass class BuilderArgs: @@ -55,7 +55,10 @@ class BuilderArgs: device: Optional[str] = None precision: torch.dtype = torch.float32 setup_caches: bool = False - use_distributed: bool = False + distributed: bool = False + num_gpus: int = 1 + num_nodes: int = 1 + pp_dim: int = 1 is_chat_model: bool = False prefill_possible: bool = False dynamic_shapes: bool = False @@ -156,7 +159,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": dtype = torch.float16 else: dtype = name_to_dtype(args.dtype, args.device) - + # distributed args + distributed = getattr(args, "distributed", False) + num_gpus = getattr(args, "num_gpus", 1) + num_nodes = getattr(args, "num_nodes", 1) + pp_dim = getattr(args, "pp_dim", 1) return cls( checkpoint_dir=checkpoint_dir, checkpoint_path=checkpoint_path, @@ -170,7 +177,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": device=args.device, precision=dtype, setup_caches=(output_dso_path or output_pte_path), - use_distributed=args.distributed, + distributed=distributed, + num_gpus=num_gpus, + num_nodes=num_nodes, + pp_dim=pp_dim, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), @@ -400,10 +410,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model: # does not host any actual values, need to reinitialize them in the actual # device. Only do those buffer initialization, without initializing the entire # model. 
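As background for the meta-device comment above: buffers created while the model lives on the meta device carry only shape and dtype, not values, so they must be recomputed once the model is materialized on a real device, which is why the surrounding hunk re-runs Llama3ScaledRoPE.__init__(head_dim, max_seq_len, rope_base). A tiny standalone illustration of that behavior (editorial sketch, not part of the patch):

import torch

with torch.device("meta"):
    buf = torch.zeros(4, 4)            # meta tensor: shape/dtype only, no storage
print(buf.is_meta)                     # True -- reading its values would fail
buf = torch.zeros(4, 4, device="cpu")  # re-creating on a real device restores actual values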
- decoder_config = model.config.transformer_args['decoder'] - head_dim = decoder_config['embed_dim'] // decoder_config['num_heads'] - max_seq_len = decoder_config['max_seq_len'] - rope_base = decoder_config['rope_base'] + decoder_config = model.config.transformer_args["decoder"] + head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"] + max_seq_len = decoder_config["max_seq_len"] + rope_base = decoder_config["rope_base"] for submodule in model.modules(): if isinstance(submodule, Llama3ScaledRoPE): submodule.__init__(head_dim, max_seq_len, rope_base) @@ -491,6 +501,7 @@ def _load_model(builder_args: BuilderArgs) -> Model: model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() + def _initialize_model( builder_args: BuilderArgs, quantize, diff --git a/torchchat/generate.py b/torchchat/generate.py index a9094aa40..3d0b9be4e 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -24,15 +24,6 @@ from PIL import Image -# torchtune model definition dependencies -from torchtune.data import Message, padded_collate_tiled_images_and_mask - -from torchtune.generation import sample as tune_sample -from torchtune.models.llama3 import llama3_tokenizer - -from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform -from torchtune.training import set_default_dtype - from torchchat.cli.builder import ( _initialize_model, _initialize_tokenizer, @@ -43,6 +34,15 @@ from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info +# torchtune model definition dependencies +from torchtune.data import Message, padded_collate_tiled_images_and_mask + +from torchtune.generation import sample as tune_sample +from torchtune.models.llama3 import llama3_tokenizer + +from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform +from torchtune.training import set_default_dtype + class _ChatFormatter(ABC): def __init__(self, tokenizer): @@ -239,23 +239,17 @@ def __init__( self.is_torchtune_model = generator_args.is_torchtune_model self.dtype = builder_args.precision - # global print - # from tp import maybe_init_dist - # rank = maybe_init_dist() - # use_distributed = False self.rank: Optional[int] = None - # if use_distributed: - # if rank != 0: - # # only print on rank 0 - # print = lambda *args, **kwargs: None print( f"Using device={self.builder_args.device} {get_device_info(self.builder_args.device)}" ) set_precision(self.builder_args.precision) - if builder_args.use_distributed: + if builder_args.distributed: + print(f"Using distributed={builder_args.distributed}") device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) + assert False, "Distributed is not supported yet" self.is_speculative = self.speculative_builder_args.checkpoint_path is not None if generator_args.chat_mode and not self.builder_args.is_chat_model: @@ -938,7 +932,8 @@ def chat( TransformerCrossAttentionLayer, TransformerSelfAttentionLayer, ) - decoder = self.model.model.decoder + + decoder = self.model.model.decoder for m in reversed(list(decoder.modules())): if isinstance(m, TransformerSelfAttentionLayer) or isinstance( m, TransformerCrossAttentionLayer @@ -984,7 +979,10 @@ def chat( # `is_torchtune_model` is a misnomer since it doesn't capture all # torchtune models (i.e. 
Flamingo) # See Issue: https://github.com/pytorch/torchchat/issues/1273 - elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo: + elif ( + not generator_args.is_torchtune_model + and self.model.config.model_type != ModelType.Flamingo + ): max_seq_length = min( encoded.size(0) + generator_args.max_new_tokens, ( From 2f1787ccf76bb15eec00f0e90e04e9fdba8f4812 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 14 Oct 2024 19:23:55 -0700 Subject: [PATCH 02/24] add tp_dim --- torchchat/cli/builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index dcff46bbe..1552bdad2 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -59,6 +59,7 @@ class BuilderArgs: num_gpus: int = 1 num_nodes: int = 1 pp_dim: int = 1 + tp_dim: int = 1 is_chat_model: bool = False prefill_possible: bool = False dynamic_shapes: bool = False @@ -164,6 +165,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": num_gpus = getattr(args, "num_gpus", 1) num_nodes = getattr(args, "num_nodes", 1) pp_dim = getattr(args, "pp_dim", 1) + tp_dim = getattr(args, "tp_dim", 1) return cls( checkpoint_dir=checkpoint_dir, checkpoint_path=checkpoint_path, @@ -181,6 +183,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": num_gpus=num_gpus, num_nodes=num_nodes, pp_dim=pp_dim, + tp_dim=tp_dim, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), From fd3ddcd3f58cfd4e5720f624ab8ecb05d1ddc023 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 16 Oct 2024 17:36:46 -0700 Subject: [PATCH 03/24] add elastic_launch --- torchchat/cli/builder.py | 80 +++++++++++++++++++++++++++++++++++----- torchchat/generate.py | 18 ++++++--- 2 files changed, 83 insertions(+), 15 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 1552bdad2..ace34f1f7 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,7 +16,12 @@ import torch._inductor.config import torch.nn as nn +from torch.distributed import launcher + from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.elastic.utils.distributed import get_free_port +from torch.distributed.launcher.api import elastic_launch from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama @@ -58,8 +63,8 @@ class BuilderArgs: distributed: bool = False num_gpus: int = 1 num_nodes: int = 1 - pp_dim: int = 1 - tp_dim: int = 1 + pp: int = 1 + tp: int = 1 is_chat_model: bool = False prefill_possible: bool = False dynamic_shapes: bool = False @@ -164,8 +169,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": distributed = getattr(args, "distributed", False) num_gpus = getattr(args, "num_gpus", 1) num_nodes = getattr(args, "num_nodes", 1) - pp_dim = getattr(args, "pp_dim", 1) - tp_dim = getattr(args, "tp_dim", 1) + pp = getattr(args, "pp", 1) + tp = getattr(args, "tp", 1) return cls( checkpoint_dir=checkpoint_dir, checkpoint_path=checkpoint_path, @@ -182,8 +187,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": distributed=distributed, num_gpus=num_gpus, num_nodes=num_nodes, - pp_dim=pp_dim, - tp_dim=tp_dim, + pp=pp, + tp=tp, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), @@ -492,19 +497,70 @@ def 
_maybe_parellelize_model( def _load_model(builder_args: BuilderArgs) -> Model: - world_mesh, parallel_dims = _maybe_init_distributed(builder_args) + # world_mesh, parallel_dims = _maybe_init_distributed(builder_args) if builder_args.gguf_path: model = _load_model_gguf(builder_args) - elif builder_args.use_distributed: - model = _init_model_on_meta_device(builder_args) + # elif builder_args.use_distributed: + # model = _init_model_on_meta_device(builder_args) else: model = _load_model_default(builder_args) - model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims) + # model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims) model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() +@record +def run_main(local_rank): + # Add the directory containing the train file to sys.path + train_file_path = Path(__file__).parent.parent.parent / "dist_run.py" + print(f"******* {train_file_path=}") + sys.path.insert(0, os.path.dirname(os.path.abspath(train_file_path))) + + # Set environment variables for distributed training + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["RANK"] = str( + local_rank # + kwargs.get("node_rank", 0) * num_processes_per_node + ) + os.environ["WORLD_SIZE"] = str(4 * 1) # num_nodes) + + # Execute the train file + with open(train_file_path, "rb") as file: + exec(compile(file.read(), train_file_path, "exec")) + + +def _launch_distributed_inference(builder_args: BuilderArgs) -> None: + # create programmatic elastic launch + print("Launching distributed inference ...") + + num_processes_per_node = 4 # builder_args.num_gpus + 1 + + lc = launcher.LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=num_processes_per_node, + # run_id=str(uuid.uuid4()), + rdzv_backend="c10d", + rdzv_endpoint="localhost:29401", + max_restarts=0, + monitor_interval=1, + ) + + train_file_path = Path(__file__).parent / "distributed" / "dist_run.py" + + elastic_launch( + config=lc, + entrypoint=run_main, + )(train_file_path) + print( + f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs." 
+ ) + # role=role, *args, **kwargs) + + # assert False, "distributed inference is not supported yet" + # pass + + def _initialize_model( builder_args: BuilderArgs, quantize, @@ -513,6 +569,10 @@ def _initialize_model( support_tensor_subclass: bool = True, ) -> Model: print("Loading model...") + if builder_args.distributed: + # we part ways here with torchchat cli and move into dist inference + _launch_distributed_inference(builder_args) + return None if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path): print("Setting gguf_kwargs for generate.") diff --git a/torchchat/generate.py b/torchchat/generate.py index 3d0b9be4e..339b4bf85 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -245,11 +245,7 @@ def __init__( f"Using device={self.builder_args.device} {get_device_info(self.builder_args.device)}" ) set_precision(self.builder_args.precision) - if builder_args.distributed: - print(f"Using distributed={builder_args.distributed}") - device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") - torch.cuda.set_device(device) - assert False, "Distributed is not supported yet" + self.is_speculative = self.speculative_builder_args.checkpoint_path is not None if generator_args.chat_mode and not self.builder_args.is_chat_model: @@ -1205,6 +1201,15 @@ def callback(x, *, done_generating=False): print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") +def _launch_distributed_inference( + builder_args: BuilderArgs, +): + from torch.distributed import launcher + from torch.distributed.elastic.utils.distributed import get_free_port + + print("Launching distributed inference within generator") + + def main(args): builder_args = BuilderArgs.from_args(args) speculative_builder_args = BuilderArgs.from_speculative_args(args) @@ -1221,5 +1226,8 @@ def main(args): ) if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() + if builder_args.distributed: + + return for _ in gen.chat(generator_args): pass From bf79697134332ca036d9ed247df11b164e63de85 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 16 Oct 2024 18:37:12 -0700 Subject: [PATCH 04/24] working, can now launch from cli --- dist_run.py | 16 +++++++++----- torchchat/cli/builder.py | 48 ++++++++++++++++++++++++++-------------- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/dist_run.py b/dist_run.py index 9af0fe154..72fc5855a 100644 --- a/dist_run.py +++ b/dist_run.py @@ -20,14 +20,14 @@ from torch.distributed.pipelining import PipelineStage, ScheduleGPipe from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs -from torchchat.distributed.logging_utils import SingletonLogger - # TODO - these are not distributed specific, consider moving to new package from torchchat.distributed.checkpoint_utils import ( get_hf_config_file, load_weights_from_hf_format, load_weights_from_torchchat_format, ) + +from torchchat.distributed.logging_utils import SingletonLogger from torchchat.distributed.utils import ( bytes_to_readable, Color as color, @@ -153,7 +153,9 @@ def _load_model_weights( # This format stands for: # single binary file, OR # multiple binary files without index files. 
- load_weights_from_torchchat_format(stage_module, distribution, device, model_config) + load_weights_from_torchchat_format( + stage_module, distribution, device, model_config + ) else: raise ValueError(f"Unknown checkpoint format: {chpt_from}") @@ -304,7 +306,7 @@ def _cleanup(): def main(args): - model_name = args.model_name + model_name = "llama3" # args.model_name pp_degree = args.pp rank, world_size = _init_distributed() @@ -590,12 +592,14 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( + """parser.add_argument( "model_name", type=str, + default="llama3", help="Name of the model to load", - choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), + # choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), ) + """ parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") parser.add_argument( "--ntokens", diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index ace34f1f7..3553702d0 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -510,23 +510,29 @@ def _load_model(builder_args: BuilderArgs) -> Model: return model.eval() -@record -def run_main(local_rank): - # Add the directory containing the train file to sys.path - train_file_path = Path(__file__).parent.parent.parent / "dist_run.py" - print(f"******* {train_file_path=}") - sys.path.insert(0, os.path.dirname(os.path.abspath(train_file_path))) +import importlib.util +import subprocess + - # Set environment variables for distributed training - os.environ["LOCAL_RANK"] = str(local_rank) - os.environ["RANK"] = str( - local_rank # + kwargs.get("node_rank", 0) * num_processes_per_node +def run_script(script_path, *args): + # Construct the command to run the script + cmd = [sys.executable, script_path] + list(args) + + # Run the script as a subprocess + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True ) - os.environ["WORLD_SIZE"] = str(4 * 1) # num_nodes) - # Execute the train file - with open(train_file_path, "rb") as file: - exec(compile(file.read(), train_file_path, "exec")) + # Stream the output in real-time + for line in process.stdout: + print(line, end="") + for line in process.stderr: + print(line, end="", file=sys.stderr) + + # Wait for the process to complete and get the return code + return_code = process.wait() + if return_code != 0: + raise subprocess.CalledProcessError(return_code, cmd) def _launch_distributed_inference(builder_args: BuilderArgs) -> None: @@ -546,12 +552,20 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None: monitor_interval=1, ) - train_file_path = Path(__file__).parent / "distributed" / "dist_run.py" + train_file_path = Path(__file__).parent.parent.parent / "dist_run.py" + print(f"train_file_path: {train_file_path}") + # import argparse + + # parser2 = argparse.ArgumentParser() + + # args = parser2.parse_args() + args = [] + print(f"args: {args}") elastic_launch( config=lc, - entrypoint=run_main, - )(train_file_path) + entrypoint=run_script, + )(train_file_path, *args) print( f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs." 
) From 26a94550a1092d98a1b3e61c9c62076187e5f436 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 16 Oct 2024 09:49:11 -0700 Subject: [PATCH 05/24] Remove numpy < 2.0 pin to align with pytorch (#1301) Fix #1296 Align with https://github.com/pytorch/pytorch/blame/main/requirements.txt#L5 --- install/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install/requirements.txt b/install/requirements.txt index 3329563b4..e22fdedf3 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -12,7 +12,7 @@ tiktoken # Miscellaneous snakeviz sentencepiece -numpy < 2.0 +numpy gguf blobfile tomli >= 1.1.0 ; python_version < "3.11" From 5f0ca00ad9193801a35c3aa5635213da61424cac Mon Sep 17 00:00:00 2001 From: vmpuri <45368418+vmpuri@users.noreply.github.com> Date: Wed, 16 Oct 2024 13:42:56 -0700 Subject: [PATCH 06/24] Update torchtune pin to 0.4.0-dev20241010 (#1300) Co-authored-by: vmpuri --- install/install_requirements.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index a05e255db..6344509d8 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -53,7 +53,7 @@ PYTORCH_NIGHTLY_VERSION=dev20241002 VISION_NIGHTLY_VERSION=dev20241002 # Nightly version for torchtune -TUNE_NIGHTLY_VERSION=dev20240928 +TUNE_NIGHTLY_VERSION=dev20241010 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same ( @@ -78,7 +78,7 @@ fi REQUIREMENTS_TO_INSTALL=( torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}" torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" - torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}" + torchtune=="0.4.0.${TUNE_NIGHTLY_VERSION}" ) # Install the requirements. --extra-index-url tells pip to look for package From 598caf5374e71e937e71e703b035586d6fe4e2b6 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 16 Oct 2024 15:22:10 -0700 Subject: [PATCH 07/24] Unbreak gguf util CI job by fixing numpy version (#1307) Setting numpy version to be the range required by gguf: https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/pyproject.toml --- install/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/install/requirements.txt b/install/requirements.txt index e22fdedf3..bda626257 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -12,7 +12,8 @@ tiktoken # Miscellaneous snakeviz sentencepiece -numpy +# numpy version range required by GGUF util +numpy >= 1.17, < 2.0 gguf blobfile tomli >= 1.1.0 ; python_version < "3.11" From 6fe1646cfd422d435b7e40f6f692642c084fd17c Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 16 Oct 2024 22:54:26 -0700 Subject: [PATCH 08/24] Remove apparently-unused import torchvision in model.py (#1305) Co-authored-by: vmpuri <45368418+vmpuri@users.noreply.github.com> --- torchchat/model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index 25b4ddcd7..7868b6593 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -12,8 +12,6 @@ from enum import Enum from pathlib import Path -import torchvision - from typing import Any, Callable, Dict, Optional, Union from collections.abc import Hashable From 78debce4b92cb4c4b0eb58130a0874500ae76302 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 17 Oct 2024 12:24:55 -0700 Subject: [PATCH 09/24] remove global var for tokenizer type + patch tokenizer to allow list of sequences --- dist_run.py | 54 
+++++++++++++++++++++++++---------------------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/dist_run.py b/dist_run.py index 72fc5855a..33ba71fc9 100644 --- a/dist_run.py +++ b/dist_run.py @@ -12,13 +12,13 @@ import os from enum import auto, Enum from pathlib import Path -from types import SimpleNamespace +from types import SimpleNamespace, MethodType from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed as dist from torch.distributed.pipelining import PipelineStage, ScheduleGPipe -from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs +from torchchat.cli.builder import TokenizerArgs # TODO - these are not distributed specific, consider moving to new package from torchchat.distributed.checkpoint_utils import ( @@ -50,7 +50,6 @@ logger = SingletonLogger.get_logger() -_tokenizer_type = None # global variable to store the tokenizer type # Using model name to identify the model to load, for example "llama2-7b-chat". # You can change it to other values listed below. @@ -61,11 +60,6 @@ } -class TokenizerType(Enum): - Tiktoken = auto() - SentencePiece = auto() - - def _init_distributed(): dist.init_process_group("nccl") rank = dist.get_rank() @@ -82,14 +76,29 @@ def _create_device_mesh(mesh_dimensions): def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace: return SimpleNamespace(**dictionary) +def _patch_tokenizer(tokenizer): + """Patch the tokenizer to support decoding of token ids.""" + if isinstance(tokenizer, TiktokenTokenizer): + # Patch tiktokenizer to allow a list of sequences. + #TODO: Upstream to tokenizer modules + old_decode = tokenizer.decode + + def decode(self, token_ids: List[int|List[int]], *args, **kwargs) -> str | List[str]: + if len(token_ids)<1: + return "" + if isinstance(token_ids[0], list): + return [old_decode(t, *args, **kwargs) for t in token_ids] + else: + return old_decode(token_ids, *args, **kwargs) + + tokenizer.decode = MethodType(decode, tokenizer) + return tokenizer def _build_chat_tokenizer( model_name: str, model_base_name: Optional[str] = None, ) -> SentencePieceProcessor | TiktokenTokenizer: - """Builds a tokenizer for the given model name, and sets the global tokenizer type variable""" - - global _tokenizer_type + """Builds a tokenizer for the given model name""" # Try to infer the model base name from the model name: # e.g. "llama2-7b-chat" -> "llama2" @@ -112,20 +121,14 @@ def _build_chat_tokenizer( } args = dict_to_args(tokenconfig) tokenizer_args = TokenizerArgs.from_args(args) - tokenizer = _initialize_tokenizer(tokenizer_args) + tokenizer = tokenizer_args.t assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}" logger.info( f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}" ) - # set global variable _tokenizer_type - if isinstance(tokenizer, TiktokenTokenizer): - _tokenizer_type = TokenizerType.Tiktoken - elif isinstance(tokenizer, SentencePieceProcessor): - _tokenizer_type = TokenizerType.SentencePiece - else: - raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}") - logger.info(f"tokenizer type = {_tokenizer_type}") + tokenizer = _patch_tokenizer(tokenizer) + return tokenizer @@ -568,15 +571,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # token ids. Thus cat'ing along dim 1. res = torch.cat(res, dim=1) res_list = res.tolist() - if _tokenizer_type == TokenizerType.Tiktoken: - # For TiktokenTokenizer, we need to decode prompt by prompt. 
- # TODO: is there a better way to do this? - responses = [tokenizer.decode(sequence) for sequence in res_list] - elif _tokenizer_type == TokenizerType.SentencePiece: # SentencePieceProcessor - # For SentencePieceProcessor, we can decode the entire 2D list at once. - responses = tokenizer.decode(res_list) - else: - raise ValueError(f"Unknown tokenizer type {_tokenizer_type}") + + responses = tokenizer.decode(res_list) # Show prompts and responses for prompt_text, response_text in zip(prompt, responses): From 2eefb1347a52ac98db829b5cd87c20990b455183 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 17 Oct 2024 12:27:07 -0700 Subject: [PATCH 10/24] make pp tp visible in interface --- dist_run.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dist_run.py b/dist_run.py index 33ba71fc9..11e87b3f2 100644 --- a/dist_run.py +++ b/dist_run.py @@ -69,8 +69,8 @@ def _init_distributed(): return rank, world_size -def _create_device_mesh(mesh_dimensions): - return dist.init_device_mesh("cuda", mesh_dimensions, mesh_dim_names=("pp", "tp")) +def _create_device_mesh(pp_degree, tp_degree): + return dist.init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")) def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace: @@ -343,8 +343,7 @@ def main(args): tp_degree = world_size // pp_degree # Create device mesh - mesh_dimensions = (pp_degree, tp_degree) - device_mesh = _create_device_mesh(mesh_dimensions) + device_mesh = _create_device_mesh(pp_degree, tp_degree) tp_mesh = device_mesh["tp"] pp_mesh = device_mesh["pp"] logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") From e8bb0764c4d9e0dacdc89caf767eb469342dd228 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 17 Oct 2024 12:42:33 -0700 Subject: [PATCH 11/24] Add llama 3.1 to dist_run.py --- dist_run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dist_run.py b/dist_run.py index 11e87b3f2..42e3ec5e4 100644 --- a/dist_run.py +++ b/dist_run.py @@ -57,6 +57,7 @@ NAME_TO_DISTRIBUTION_AND_DTYPE = { "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16), "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16), + "llama3.1": ("meta-llama/Meta-Llama-3.1-8B-Instruct", torch.bfloat16), } From 1faa0528bcf50610a123b75dacf7887472da5fa5 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:13:46 -0700 Subject: [PATCH 12/24] [WIP] Move dist inf into its own generator --- torchchat/cli/builder.py | 72 +--------- torchchat/cli/cli.py | 16 +++ .../distributed/dist_run.py | 0 torchchat/distributed/generate.py | 128 ++++++++++++++++++ torchchat/generate.py | 49 +++++-- 5 files changed, 179 insertions(+), 86 deletions(-) rename dist_run.py => torchchat/distributed/dist_run.py (100%) create mode 100644 torchchat/distributed/generate.py diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 3553702d0..49ce9a077 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -508,72 +508,7 @@ def _load_model(builder_args: BuilderArgs) -> Model: model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() - - -import importlib.util -import subprocess - - -def run_script(script_path, *args): - # Construct the command to run the script - cmd = [sys.executable, script_path] + list(args) - - # Run the script as a subprocess - process = subprocess.Popen( - cmd, 
stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True - ) - - # Stream the output in real-time - for line in process.stdout: - print(line, end="") - for line in process.stderr: - print(line, end="", file=sys.stderr) - - # Wait for the process to complete and get the return code - return_code = process.wait() - if return_code != 0: - raise subprocess.CalledProcessError(return_code, cmd) - - -def _launch_distributed_inference(builder_args: BuilderArgs) -> None: - # create programmatic elastic launch - print("Launching distributed inference ...") - - num_processes_per_node = 4 # builder_args.num_gpus + 1 - - lc = launcher.LaunchConfig( - min_nodes=1, - max_nodes=1, - nproc_per_node=num_processes_per_node, - # run_id=str(uuid.uuid4()), - rdzv_backend="c10d", - rdzv_endpoint="localhost:29401", - max_restarts=0, - monitor_interval=1, - ) - - train_file_path = Path(__file__).parent.parent.parent / "dist_run.py" - print(f"train_file_path: {train_file_path}") - # import argparse - - # parser2 = argparse.ArgumentParser() - - # args = parser2.parse_args() - args = [] - print(f"args: {args}") - - elastic_launch( - config=lc, - entrypoint=run_script, - )(train_file_path, *args) - print( - f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs." - ) - # role=role, *args, **kwargs) - - # assert False, "distributed inference is not supported yet" - # pass - + def _initialize_model( builder_args: BuilderArgs, @@ -583,11 +518,6 @@ def _initialize_model( support_tensor_subclass: bool = True, ) -> Model: print("Loading model...") - if builder_args.distributed: - # we part ways here with torchchat cli and move into dist inference - _launch_distributed_inference(builder_args) - return None - if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path): print("Setting gguf_kwargs for generate.") is_dso = builder_args.dso_path is not None diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 1d624c6c4..fd9f2963d 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -409,6 +409,22 @@ def _add_distributed_args(parser) -> None: help=argparse.SUPPRESS, # "Use the specified model checkpoint directory", ) + parser.add_argument( + "--pp", + "--pipeline-parallel", + type=int, + default=1, + help=argparse.SUPPRESS, + # "Pipeline parallel degree", + ) + parser.add_argument( + "--tp", + "--tensor-parallel", + type=int, + default=1, + help=argparse.SUPPRESS, + # "Tensor parallel degree", + ) # Add CLI Args related to custom model inputs diff --git a/dist_run.py b/torchchat/distributed/dist_run.py similarity index 100% rename from dist_run.py rename to torchchat/distributed/dist_run.py diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py new file mode 100644 index 000000000..7bec30db5 --- /dev/null +++ b/torchchat/distributed/generate.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+from abc import abstractmethod +from typing import List, Optional +from dataclasses import dataclass +from pathlib import Path +from torchchat.cli.builder import BuilderArgs, TokenizerArgs + + +import importlib.util +import subprocess + + +def run_script(script_path, *args): + # Construct the command to run the script + cmd = [sys.executable, script_path] + list(args) + + # Run the script as a subprocess + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) + + # Stream the output in real-time + for line in process.stdout: + print(line, end="") + for line in process.stderr: + print(line, end="", file=sys.stderr) + + # Wait for the process to complete and get the return code + return_code = process.wait() + if return_code != 0: + raise subprocess.CalledProcessError(return_code, cmd) + + +def _launch_distributed_inference(builder_args: BuilderArgs) -> None: + # create programmatic elastic launch + print("Launching distributed inference ...") + + num_processes_per_node = 4 # builder_args.num_gpus + 1 + + lc = launcher.LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=num_processes_per_node, + # run_id=str(uuid.uuid4()), + rdzv_backend="c10d", + rdzv_endpoint="localhost:29401", + max_restarts=0, + monitor_interval=1, + ) + + # train_file_path = Path(__file__).parent.parent.parent / "dist_run.py" + # print(f"train_file_path: {train_file_path}") + # import argparse + + # parser2 = argparse.ArgumentParser() + + # args = parser2.parse_args() + args = [] + print(f"args: {args}") + + from dist_run import main + + elastic_launch( + config=lc, + entrypoint=run_script, + )(main, *args) + print( + f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs." + ) + # role=role, *args, **kwargs) + + # assert False, "distributed inference is not supported yet" + # pass + +@dataclass +class Output: + request_id: int + is_finished: bool = False + output: Optional[str] = None + +class Generator(object): + + @abstractmethod + def add_request(self, request_id: int, prompt: str): + raise NotImplementedError() + + def step(self) -> List[Output]: + raise NotImplementedError() + + +class DistributedGenerator(Generator): + def __init__( + self, + builder_args: BuilderArgs, + speculative_builder_args: BuilderArgs, + tokenizer_args: TokenizerArgs, + #TODO: move GeneratorArgs into a different module + # generator_args: GeneratorArgs, + profile: Optional[Path], + quantize: bool, + draft_quantize: bool, + ): + self.requests = {} + # if builder_args.distributed: + # # we part ways here with torchchat cli and move into dist inference + _launch_distributed_inference(builder_args) + # return None + + + def add_request(self, request_id: int, prompt: str): + assert request_id not in self.requests + self.requests[request_id] = prompt + + + def step(self) -> List[Output]: + outputs = [] + for request_id, prompt in self.requests.items(): + outputs.append(Output(request_id, is_finished=True, output=prompt)) + + for output in outputs: + if output.is_finished: + del self.requests[output.request_id] + + return outputs diff --git a/torchchat/generate.py b/torchchat/generate.py index 339b4bf85..c1faeb830 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -31,6 +31,7 @@ TokenizerArgs, ) from torchchat.model import Model, ModelType +from torchchat.distributed.generate import DistributedGenerator from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info @@ -1215,19 
+1216,37 @@ def main(args): speculative_builder_args = BuilderArgs.from_speculative_args(args) tokenizer_args = TokenizerArgs.from_args(args) generator_args = GeneratorArgs.from_args(args) - gen = Generator( - builder_args, - speculative_builder_args, - tokenizer_args, - generator_args, - args.profile, - args.quantize, - args.draft_quantize, - ) - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - if builder_args.distributed: + if not builder_args.distributed: + gen = Generator( + builder_args, + speculative_builder_args, + tokenizer_args, + generator_args, + args.profile, + args.quantize, + args.draft_quantize, + ) + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + + + for _ in gen.chat(generator_args): + pass + else: + dist_gen = DistributedGenerator( + builder_args, + speculative_builder_args, + tokenizer_args, + # generator_args, + args.profile, + args.quantize, + args.draft_quantize, + ) + + dist_gen.add_request(0, "Tell me a joke") + dist_gen.add_request(1, "Tell me another joke") - return - for _ in gen.chat(generator_args): - pass + outputs = dist_gen.step() + while len(outputs): + print(outputs) + outputs = dist_gen.step() From 11f29fc22d57a069dbc18f0874234e2d3e40a5ae Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 21 Oct 2024 16:32:58 -0700 Subject: [PATCH 13/24] Add initial generator interface to dist inference --- torchchat/cli/builder.py | 11 +- torchchat/cli/cli.py | 14 ++ torchchat/distributed/dist_run.py | 266 ++++++++++++++++-------------- torchchat/distributed/generate.py | 121 ++++++++------ torchchat/generate.py | 18 +- 5 files changed, 247 insertions(+), 183 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 49ce9a077..11b465e2a 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,12 +16,9 @@ import torch._inductor.config import torch.nn as nn -from torch.distributed import launcher - from torch.distributed.device_mesh import DeviceMesh from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.elastic.utils.distributed import get_free_port -from torch.distributed.launcher.api import elastic_launch from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama @@ -65,6 +62,8 @@ class BuilderArgs: num_nodes: int = 1 pp: int = 1 tp: int = 1 + chpt_from: str = "hf" + ntokens: int = 40 is_chat_model: bool = False prefill_possible: bool = False dynamic_shapes: bool = False @@ -171,6 +170,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": num_nodes = getattr(args, "num_nodes", 1) pp = getattr(args, "pp", 1) tp = getattr(args, "tp", 1) + chpt_from = getattr(args, "chpt_from", "hf") + ntokens = getattr(args, "ntokens", 40) return cls( checkpoint_dir=checkpoint_dir, checkpoint_path=checkpoint_path, @@ -189,6 +190,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": num_nodes=num_nodes, pp=pp, tp=tp, + chpt_from=chpt_from, + ntokens=ntokens, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), @@ -508,7 +511,7 @@ def _load_model(builder_args: BuilderArgs) -> Model: model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() - + def _initialize_model( builder_args: BuilderArgs, diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index fd9f2963d..dad073a14 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -426,6 
+426,20 @@ def _add_distributed_args(parser) -> None: # "Tensor parallel degree", ) + parser.add_argument( + "--ntokens", + type=int, + default=40, + help="Number of tokens to generate", + ) + parser.add_argument( + "--chpt-from", + type=str, + default="hf", # TODO: change to torchchat once we support it well + help="Checkpoint format to load from", + choices=["hf", "torchchat"], + ) + # Add CLI Args related to custom model inputs def _add_custom_model_args(parser) -> None: diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py index 42e3ec5e4..6a74e95d1 100644 --- a/torchchat/distributed/dist_run.py +++ b/torchchat/distributed/dist_run.py @@ -221,8 +221,7 @@ def _create_padded_prompts( def _batch_decode_next_tokens( output: torch.Tensor, - pos: List[int], - step: int = -1, + pos: List[int]=None, temperature: float = 1.0, topk: int = 10, ) -> torch.Tensor: @@ -240,7 +239,7 @@ def _batch_decode_next_tokens( """ batch_size, seq_len, vocab_size = output.shape - if step != -1: + if pos is None: # `pos` is not provided, so we can use the first token next_token_logits = output[:, 0, :] else: @@ -283,8 +282,6 @@ def _decode_in_flight(token, tokenizer, tp_rank): # `token` is a tensor of shape (batch_size, 1). # For TiktokenTokenizer, we need to squeeze it to 1D. # For SentencePieceProcessor, we don't. - if isinstance(tokenizer, TiktokenTokenizer): - token = torch.squeeze(token, dim=1) token_str = tokenizer.decode(token.tolist()) # print the token string on tp rank 0 if tp_rank == 0: @@ -292,6 +289,7 @@ def _decode_in_flight(token, tokenizer, tp_rank): f"{color.green} responses ====>>>> " f"{color.blue} {token_str} {color.reset}" ) + return token_str def _cleanup(): @@ -299,7 +297,7 @@ def _cleanup(): dist.destroy_process_group() -prompt = [ +prompts = [ "What is Snow?", # "Can you explain what is the purpose of back propagation in neural networks?", "Who is Santa Claus?", @@ -309,11 +307,12 @@ def _cleanup(): ] -def main(args): +def main(args, pipe): model_name = "llama3" # args.model_name pp_degree = args.pp rank, world_size = _init_distributed() + logger.info(f"Worker started: {rank=}, {world_size=}") gpu_memory_monitor = GPUMemoryMonitor("cuda") logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") @@ -389,7 +388,7 @@ def main(args): # Batch size. Since we push batches dynamically through the pipeline rather # than chunking them, this is effectively micro-batch size in pipeline # sense. Thus it is interchangeable with micro-batch size below. - batch_size = len(prompt) + batch_size = 4# len(prompt) seqlen_prefill = 1024 # sequence length dim = 4096 # embedding dimension @@ -451,134 +450,161 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # pipelining effect. 
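The hunk that follows replaces the one-shot prefill/decode flow with a command loop driven over a multiprocessing Pipe: the worker announces "ready", then serves prompt batches and "step"/"stop" commands sent by the parent generator. A minimal, simplified sketch of that parent/worker protocol (editorial sketch; the reply payloads here are made up, the real worker sends decoded token strings):

import torch.multiprocessing as mp

def worker(pipe):
    pipe.send("ready")
    while True:
        cmd = pipe.recv()
        if cmd == "stop":
            break
        # a list is treated as a new prompt batch; reply with one string per prompt
        pipe.send([f"token-for:{p}" for p in cmd])

if __name__ == "__main__":
    parent, child = mp.Pipe(duplex=True)
    proc = mp.Process(target=worker, args=(child,))
    proc.start()
    assert parent.recv() == "ready"
    parent.send(["What is Snow?", "Who is Santa Claus?"])
    print(parent.recv())
    parent.send("stop")
    proc.join()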
prefiller = ScheduleGPipe(prefill_stage, 1) - start_pos = 0 - # Need these global ids due to the API definition of dist.send and recv first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) - # encode the prompt - input_ids = _encode_strings( - prompt, tokenizer, bos=True, device=device, dtype=torch.int64 - ) + pipe.send("ready") - # create a padded tensor for the input prompt - padded_sequence, prompt_lengths = _create_padded_prompts( - input_ids, tokenizer, seqlen_prefill, start_pos, device - ) + while True: + command = pipe.recv() + assert isinstance(command, (str, list)) + if isinstance(command, str): + if command == "stop": + break + else: + raise ValueError(f"Unknown command: {command}") + else: + prompt = command + assert len(prompt) == batch_size, f"Expecting {batch_size=} prompts but got {len(prompt)=}" + logger.info(f"{color.green}Prompt: {prompt}{color.reset}") - # Need these global ids due to the API definition of dist.send and recv - first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) - last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) + start_pos = 0 - # New token generated each iteration - # need a row dimension for each prompt in the batch - new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64) - # Store the generated tokens - res = [] - - # Prefill phase - # Run context input through pipeline - # TODO: we need to pass `input_pos` and `cache_lane` to each stage. - lane = 0 - kwargs = {"input_pos": input_pos, "cache_lane": lane} - with torch.no_grad(), CUDATrackTime() as timer: - if pp_rank == first_pp_rank: - output = prefiller.step(padded_sequence, **kwargs) - elif pp_rank == last_pp_rank: - output = prefiller.step(**kwargs) - else: # middle pp ranks - prefiller.step(**kwargs) + # encode the prompt + input_ids = _encode_strings( + prompt, tokenizer, bos=True, device=device, dtype=torch.int64 + ) - logger.info( - f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" - ) + # create a padded tensor for the input prompt + padded_sequence, prompt_lengths = _create_padded_prompts( + input_ids, tokenizer, seqlen_prefill, start_pos, device + ) - # Decode the output -- first generated token - if pp_rank == last_pp_rank: - logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}") - new_token = _batch_decode_next_tokens(output, prompt_lengths) - res.append(new_token) - if not args.disable_in_flight_decode: - _decode_in_flight(new_token, tokenizer, tp_rank) - - # seqlen = 1 now - seqlen_decode = 1 - input_pos = torch.tensor([prompt_lengths[0]], device=device) - - # Create decode stage - logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}") - example_inputs, example_outputs = get_example_ins_outs(seqlen_decode) - decode_stage = PipelineStage( - model, - pp_rank, - pp_degree, - device, - input_args=example_inputs, - output_args=example_outputs, - group=pp_group, - ) - # create schedule - decoder = ScheduleGPipe(decode_stage, 1) - - # Decoding - with torch.no_grad(), CUDATrackTime() as timer: - for step in range(args.ntokens - 1): - kwargs = {"input_pos": input_pos, "cache_lane": lane} - # sendrecv between last and first ranks, only if: - # first_pp_rank != last_pp_rank. 
- if pp_rank == last_pp_rank and pp_rank != first_pp_rank: - dist.send( - new_token, - dst=first_pp_rank_global_id, - group=pp_group, - ) - elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: - dist.recv( - new_token, - src=last_pp_rank_global_id, - group=pp_group, - ) - - # Run data through pipeline + # New token generated each iteration + # need a row dimension for each prompt in the batch + new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64) + # Store the generated tokens + res = [] + + # Prefill phase + # Run context input through pipeline + # TODO: we need to pass `input_pos` and `cache_lane` to each stage. + lane = 0 + kwargs = {"input_pos": input_pos, "cache_lane": lane} + with torch.no_grad(), CUDATrackTime() as timer: if pp_rank == first_pp_rank: - output = decoder.step(new_token, **kwargs) + output = prefiller.step(padded_sequence, **kwargs) elif pp_rank == last_pp_rank: - output = decoder.step(**kwargs) + output = prefiller.step(**kwargs) else: # middle pp ranks - decoder.step(**kwargs) - - # Decode the output - if pp_rank == last_pp_rank: - new_token = _batch_decode_next_tokens(output, prompt_lengths, step) - res.append(new_token) - if not args.disable_in_flight_decode: - _decode_in_flight(new_token, tokenizer, tp_rank) + prefiller.step(**kwargs) - # Increment input position - input_pos += 1 + logger.info( + f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) - logger.info( - f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" - ) + # Decode the output -- first generated token + if pp_rank == last_pp_rank: + logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}") + new_token = _batch_decode_next_tokens(output, prompt_lengths) + res.append(new_token) + #TODO: Move to a separate decoding thread + resp = _decode_in_flight(new_token, tokenizer, tp_rank) + pipe.send(resp) + else: + logger.info(f"sending None {tp_rank=}") + pipe.send(None) + + # seqlen = 1 now + seqlen_decode = 1 + input_pos = torch.tensor([prompt_lengths[0]], device=device) + + # Create decode stage + logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}") + example_inputs, example_outputs = get_example_ins_outs(seqlen_decode) + decode_stage = PipelineStage( + model, + pp_rank, + pp_degree, + device, + input_args=example_inputs, + output_args=example_outputs, + group=pp_group, + ) + # create schedule + decoder = ScheduleGPipe(decode_stage, 1) + + # Decoding + with torch.no_grad(), CUDATrackTime() as timer: + while True: + command = pipe.recv() + assert isinstance(command, str) + if command == "stop": + break + elif command == "step": + pass + else: + raise ValueError(f"Unknown command: {command}") + + kwargs = {"input_pos": input_pos, "cache_lane": lane} + # sendrecv between last and first ranks, only if: + # first_pp_rank != last_pp_rank. 
+ if pp_rank == last_pp_rank and pp_rank != first_pp_rank: + dist.send( + new_token, + dst=first_pp_rank_global_id, + group=pp_group, + ) + elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: + dist.recv( + new_token, + src=last_pp_rank_global_id, + group=pp_group, + ) + + # Run data through pipeline + if pp_rank == first_pp_rank: + output = decoder.step(new_token, **kwargs) + elif pp_rank == last_pp_rank: + output = decoder.step(**kwargs) + else: # middle pp ranks + decoder.step(**kwargs) + + # Decode the output + if pp_rank == last_pp_rank: + new_token = _batch_decode_next_tokens(output) + res.append(new_token) + #TODO: Move to a separate decoding thread + resp = _decode_in_flight(new_token, tokenizer, tp_rank) + pipe.send(resp) + else: + pipe.send(None) + + # Increment input position + input_pos += 1 - # Display the decoding results + logger.info( + f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) - # output formatted response via last pp group and tp rank 0 - if pp_rank == last_pp_rank and tp_rank == 0: - # `res` is a list of tensors, each being a batch of generated token ids. - # We need to concatenate them to get the full sequence of generated - # token ids. Thus cat'ing along dim 1. - res = torch.cat(res, dim=1) - res_list = res.tolist() + # Display the decoding results - responses = tokenizer.decode(res_list) + # output formatted response via last pp group and tp rank 0 + if pp_rank == last_pp_rank and tp_rank == 0: + # `res` is a list of tensors, each being a batch of generated token ids. + # We need to concatenate them to get the full sequence of generated + # token ids. Thus cat'ing along dim 1. + res = torch.cat(res, dim=1) + res_list = res.tolist() - # Show prompts and responses - for prompt_text, response_text in zip(prompt, responses): - logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}") - logger.info(f"Response: {color.red}{response_text} {color.reset}") + responses = tokenizer.decode(res_list) + # Show prompts and responses + for prompt_text, response_text in zip(prompt, responses): + logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}") + logger.info(f"Response: {color.red}{response_text} {color.reset}") + # Cleanup _cleanup() logger.info( @@ -603,12 +629,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: default=40, help="Number of tokens to generate", ) - parser.add_argument( - "--disable-in-flight-decode", - action="store_true", - default=False, - help="Whether to decode token into string in flight", - ) parser.add_argument( "--chpt-from", type=str, diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py index 7bec30db5..10fa5cbba 100644 --- a/torchchat/distributed/generate.py +++ b/torchchat/distributed/generate.py @@ -7,32 +7,25 @@ from typing import List, Optional from dataclasses import dataclass from pathlib import Path +from os import environ from torchchat.cli.builder import BuilderArgs, TokenizerArgs +from functools import partial - +import atexit +import torch.multiprocessing as mp import importlib.util import subprocess -def run_script(script_path, *args): - # Construct the command to run the script - cmd = [sys.executable, script_path] + list(args) - - # Run the script as a subprocess - process = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True - ) - - # Stream the output in real-time - for line in process.stdout: - print(line, end="") - for line in process.stderr: - 
print(line, end="", file=sys.stderr) +def _setup_env(world_size:int, rank:int, target: callable, *args, **kwargs): + environ["MASTER_ADDR"] = "localhost" + environ["MASTER_PORT"] = "29500" + environ["RDZV_BACKEND"] = "c10d" + environ["WORLD_SIZE"] = str(world_size) + environ["RANK"] = str(rank) + environ["LOCALRANK"] = str(rank) - # Wait for the process to complete and get the return code - return_code = process.wait() - if return_code != 0: - raise subprocess.CalledProcessError(return_code, cmd) + return target(*args, **kwargs) def _launch_distributed_inference(builder_args: BuilderArgs) -> None: @@ -41,40 +34,29 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None: num_processes_per_node = 4 # builder_args.num_gpus + 1 - lc = launcher.LaunchConfig( - min_nodes=1, - max_nodes=1, - nproc_per_node=num_processes_per_node, - # run_id=str(uuid.uuid4()), - rdzv_backend="c10d", - rdzv_endpoint="localhost:29401", - max_restarts=0, - monitor_interval=1, - ) - - # train_file_path = Path(__file__).parent.parent.parent / "dist_run.py" - # print(f"train_file_path: {train_file_path}") - # import argparse + from torchchat.distributed.dist_run import main + mp.set_start_method('spawn') - # parser2 = argparse.ArgumentParser() + pipes = [] + procs = [] + for rank in range(num_processes_per_node): + server_pipe, client_pipe = mp.Pipe(duplex=True) + pipes.append(server_pipe) + proc = mp.Process( + target=partial(_setup_env, num_processes_per_node, rank, main), + args=(builder_args, client_pipe) + ) + proc.start() - # args = parser2.parse_args() - args = [] - print(f"args: {args}") - from dist_run import main + for pipe in pipes: + response = pipe.recv() + print(f"Received: {response=}") - elastic_launch( - config=lc, - entrypoint=run_script, - )(main, *args) print( f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs." 
) - # role=role, *args, **kwargs) - - # assert False, "distributed inference is not supported yet" - # pass + return procs, pipes @dataclass class Output: @@ -104,25 +86,58 @@ def __init__( quantize: bool, draft_quantize: bool, ): + self.builder_args = builder_args self.requests = {} + self.in_flight_requests = {} + # For now we have a static batch order we save separately + self.in_flight_batch_order = [] # if builder_args.distributed: # # we part ways here with torchchat cli and move into dist inference - _launch_distributed_inference(builder_args) - # return None + self.procs, self.pipes = _launch_distributed_inference(builder_args) + self.current_step = 0 + + atexit.register(self.shutdown) + def shutdown(self): + for p in self.pipes: + p.send("stop") + for p in self.procs: + p.kill() + #TODO: Replace against (async) generate def add_request(self, request_id: int, prompt: str): assert request_id not in self.requests self.requests[request_id] = prompt def step(self) -> List[Output]: + responses = [] + #TODO: Implement a scheduler to handle the requests + if len(self.in_flight_requests) > 0: + #Receive decoded token + for p in self.pipes: + p.send("step") + for p in self.pipes: + responses.append(p.recv()) + + else: + # Send requests to backend + self.in_flight_batch_order = list(self.requests.keys()) + prompts = [self.requests[k] for k in self.in_flight_batch_order] + for p in self.pipes: + p.send(prompts) + self.in_flight_requests = self.requests + self.requests = {} + self.current_step = 0 + #Receive first token + for p in self.pipes: + responses.append(p.recv()) + + responses = responses[0] outputs = [] - for request_id, prompt in self.requests.items(): - outputs.append(Output(request_id, is_finished=True, output=prompt)) + for k, v in zip(self.in_flight_batch_order, responses): + outputs.append(Output(k, is_finished=self.current_step>=self.builder_args.ntokens, output=v)) - for output in outputs: - if output.is_finished: - del self.requests[output.request_id] + self.current_step += 1 return outputs diff --git a/torchchat/generate.py b/torchchat/generate.py index c1faeb830..37478c03a 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1245,8 +1245,20 @@ def main(args): dist_gen.add_request(0, "Tell me a joke") dist_gen.add_request(1, "Tell me another joke") + dist_gen.add_request(2, "Who is this Santa") + dist_gen.add_request(3, "What did the fish say to the duck") - outputs = dist_gen.step() - while len(outputs): - print(outputs) + responses = {} + + running = True + while running: outputs = dist_gen.step() + for o in outputs: + responses[o.request_id] = responses.get(o.request_id, "") + o.output + running &= not o.is_finished + + print(responses) + + dist_gen.shutdown() + + From adcf232ba9b890934b7f1695443217452e72c680 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Tue, 22 Oct 2024 21:26:35 -0700 Subject: [PATCH 14/24] Added generate method and placeholder scheduler --- torchchat/cli/builder.py | 3 - torchchat/cli/cli.py | 7 -- torchchat/distributed/dist_run.py | 12 +-- torchchat/distributed/generate.py | 169 ++++++++++++++++++++++-------- torchchat/generate.py | 22 +--- 5 files changed, 134 insertions(+), 79 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 11b465e2a..82fefb834 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -63,7 +63,6 @@ class BuilderArgs: pp: int = 1 tp: int = 1 chpt_from: str = "hf" - ntokens: int = 40 is_chat_model: bool = False 
prefill_possible: bool = False dynamic_shapes: bool = False @@ -171,7 +170,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": pp = getattr(args, "pp", 1) tp = getattr(args, "tp", 1) chpt_from = getattr(args, "chpt_from", "hf") - ntokens = getattr(args, "ntokens", 40) return cls( checkpoint_dir=checkpoint_dir, checkpoint_path=checkpoint_path, @@ -191,7 +189,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": pp=pp, tp=tp, chpt_from=chpt_from, - ntokens=ntokens, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index dad073a14..60e09eeac 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -425,13 +425,6 @@ def _add_distributed_args(parser) -> None: help=argparse.SUPPRESS, # "Tensor parallel degree", ) - - parser.add_argument( - "--ntokens", - type=int, - default=40, - help="Number of tokens to generate", - ) parser.add_argument( "--chpt-from", type=str, diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py index 6a74e95d1..21438264d 100644 --- a/torchchat/distributed/dist_run.py +++ b/torchchat/distributed/dist_run.py @@ -388,7 +388,7 @@ def main(args, pipe): # Batch size. Since we push batches dynamically through the pipeline rather # than chunking them, this is effectively micro-batch size in pipeline # sense. Thus it is interchangeable with micro-batch size below. - batch_size = 4# len(prompt) + batch_size = 1# len(prompt) seqlen_prefill = 1024 # sequence length dim = 4096 # embedding dimension @@ -410,9 +410,6 @@ def main(args, pipe): logger.info( f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}" ) - - # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen - input_pos = torch.arange(seqlen_prefill, device=device) model.eval() # Helper function to get example inputs and outputs for the stages. @@ -470,6 +467,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: logger.info(f"{color.green}Prompt: {prompt}{color.reset}") start_pos = 0 + # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen + input_pos = torch.arange(seqlen_prefill, device=device) # encode the prompt input_ids = _encode_strings( @@ -511,9 +510,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: res.append(new_token) #TODO: Move to a separate decoding thread resp = _decode_in_flight(new_token, tokenizer, tp_rank) - pipe.send(resp) + pipe.send((resp, new_token.tolist())) else: - logger.info(f"sending None {tp_rank=}") pipe.send(None) # seqlen = 1 now @@ -577,7 +575,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: res.append(new_token) #TODO: Move to a separate decoding thread resp = _decode_in_flight(new_token, tokenizer, tp_rank) - pipe.send(resp) + pipe.send((resp, new_token)) else: pipe.send(None) diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py index 10fa5cbba..af4ca028a 100644 --- a/torchchat/distributed/generate.py +++ b/torchchat/distributed/generate.py @@ -4,15 +4,19 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from abc import abstractmethod -from typing import List, Optional +from collections import deque from dataclasses import dataclass -from pathlib import Path +from functools import partial from os import environ +from pathlib import Path from torchchat.cli.builder import BuilderArgs, TokenizerArgs -from functools import partial +from typing import List, Optional +from uuid import uuid4 +import asyncio import atexit import torch.multiprocessing as mp +import threading import importlib.util import subprocess @@ -51,7 +55,6 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None: for pipe in pipes: response = pipe.recv() - print(f"Received: {response=}") print( f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs." @@ -60,56 +63,72 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None: @dataclass class Output: - request_id: int is_finished: bool = False - output: Optional[str] = None - -class Generator(object): + text: Optional[str] = None + token: Optional[list] = None - @abstractmethod - def add_request(self, request_id: int, prompt: str): - raise NotImplementedError() +@dataclass +class Request: + request_id: int + prompt: str - def step(self) -> List[Output]: - raise NotImplementedError() + @classmethod + def new_request(cls, prompt): + return cls(request_id=uuid4().int, prompt=prompt) -class DistributedGenerator(Generator): +class Scheduler(object): def __init__( self, - builder_args: BuilderArgs, - speculative_builder_args: BuilderArgs, - tokenizer_args: TokenizerArgs, - #TODO: move GeneratorArgs into a different module - # generator_args: GeneratorArgs, - profile: Optional[Path], - quantize: bool, - draft_quantize: bool, + builder_args, + generator_args, + pipes, + loop, ): self.builder_args = builder_args + self.generator_args = generator_args self.requests = {} self.in_flight_requests = {} - # For now we have a static batch order we save separately self.in_flight_batch_order = [] - # if builder_args.distributed: - # # we part ways here with torchchat cli and move into dist inference - self.procs, self.pipes = _launch_distributed_inference(builder_args) - self.current_step = 0 - - atexit.register(self.shutdown) - - def shutdown(self): - for p in self.pipes: - p.send("stop") - for p in self.procs: - p.kill() - - #TODO: Replace against (async) generate - def add_request(self, request_id: int, prompt: str): - assert request_id not in self.requests - self.requests[request_id] = prompt - - + self.pipes = pipes + self.req_to_states = {} + self.req_to_results = {} + self.request_queue = mp.Queue() + self.loop = loop + + def schedule_request(self, req: Request): + self.req_to_states[req.request_id] = asyncio.Event() + self.req_to_results[req.request_id] = deque() + self.request_queue.put(req) + + def process_requests_loop(self): + while True: + req = self.request_queue.get() + if req == "stop": + break + self.requests = {req.request_id: req.prompt} + + responses = {} + running = True + while running: + outputs = self.step() + self.req_to_results[req.request_id].append(outputs[0]) + + self.loop.call_soon_threadsafe(self.req_to_states[req.request_id].set) + + running &= not outputs[0].is_finished + + async def wait_for_request(self, req: Request) -> Output: + is_finished = False + while not is_finished: + await self.req_to_states[req.request_id].wait() + while len(self.req_to_results[req.request_id]): + output = self.req_to_results[req.request_id].popleft() + is_finished |= output.is_finished + yield output + del 
self.req_to_states[req.request_id] + del self.req_to_results[req.request_id] + def step(self) -> List[Output]: responses = [] #TODO: Implement a scheduler to handle the requests @@ -132,12 +151,72 @@ def step(self) -> List[Output]: #Receive first token for p in self.pipes: responses.append(p.recv()) - responses = responses[0] outputs = [] - for k, v in zip(self.in_flight_batch_order, responses): - outputs.append(Output(k, is_finished=self.current_step>=self.builder_args.ntokens, output=v)) + for k, v in zip(self.in_flight_batch_order, zip(responses[0], responses[1])): + text, token_ids = v + outputs.append( + Output( + is_finished=self.current_step>=self.generator_args.max_new_tokens, + text=text, + token=token_ids, + ) + ) + if self.current_step >= self.generator_args.max_new_tokens: + for p in self.pipes: + p.send("stop") + self.in_flight_requests = [] self.current_step += 1 return outputs + + +class DistributedGenerator(object): + def __init__( + self, + builder_args: BuilderArgs, + tokenizer_args: TokenizerArgs, + #TODO: move GeneratorArgs into a different module + generator_args, + profile: Optional[Path], + quantize: bool, + draft_quantize: bool, + ): + self.builder_args = builder_args + self.generate_args = generator_args + + self.procs, self.pipes = _launch_distributed_inference(builder_args) + + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.scheduler = Scheduler(builder_args, generator_args, self.pipes, self.loop) + + #TODO: Mode into process and use pipe or queue for comm + self.scheduler_thread = threading.Thread(target=self.scheduler.process_requests_loop) + self.scheduler_thread.start() + + atexit.register(self.shutdown) + + def shutdown(self): + self.scheduler.request_queue.put("stop") + self.scheduler_thread.join() + + for p in self.pipes: + p.send("stop") + for p in self.procs: + p.kill() + + def generate(self, text): + req = Request.new_request(text) + self.scheduler.schedule_request(req) + + generator = self.scheduler.wait_for_request(req) + + running = True + while running: + output = self.loop.run_until_complete(generator.__anext__()) + running &= not output.is_finished + + yield output diff --git a/torchchat/generate.py b/torchchat/generate.py index 37478c03a..ef31a359f 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1235,30 +1235,18 @@ def main(args): else: dist_gen = DistributedGenerator( builder_args, - speculative_builder_args, tokenizer_args, - # generator_args, + generator_args, args.profile, args.quantize, args.draft_quantize, ) - dist_gen.add_request(0, "Tell me a joke") - dist_gen.add_request(1, "Tell me another joke") - dist_gen.add_request(2, "Who is this Santa") - dist_gen.add_request(3, "What did the fish say to the duck") - - responses = {} + response = "" + for output in dist_gen.generate("Tell me a joke"): + response += output.text - running = True - while running: - outputs = dist_gen.step() - for o in outputs: - responses[o.request_id] = responses.get(o.request_id, "") + o.output - running &= not o.is_finished - - print(responses) - + print(f"Model output: {response}") dist_gen.shutdown() From 3836928671f681f3b08037b455ce1e75ecf8ead4 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:42:55 -0700 Subject: [PATCH 15/24] use prompt parameter for dist generation --- torchchat/cli/builder.py | 6 --- torchchat/cli/cli.py | 4 +- torchchat/distributed/generate.py | 71 ++++++++++++++++++------------- torchchat/generate.py | 2 +- 4 files changed, 
44 insertions(+), 39 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 82fefb834..8be70e611 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -58,8 +58,6 @@ class BuilderArgs: precision: torch.dtype = torch.float32 setup_caches: bool = False distributed: bool = False - num_gpus: int = 1 - num_nodes: int = 1 pp: int = 1 tp: int = 1 chpt_from: str = "hf" @@ -165,8 +163,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": dtype = name_to_dtype(args.dtype, args.device) # distributed args distributed = getattr(args, "distributed", False) - num_gpus = getattr(args, "num_gpus", 1) - num_nodes = getattr(args, "num_nodes", 1) pp = getattr(args, "pp", 1) tp = getattr(args, "tp", 1) chpt_from = getattr(args, "chpt_from", "hf") @@ -184,8 +180,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": precision=dtype, setup_caches=(output_dso_path or output_pte_path), distributed=distributed, - num_gpus=num_gpus, - num_nodes=num_nodes, pp=pp, tp=tp, chpt_from=chpt_from, diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 60e09eeac..6f9f2261f 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -448,13 +448,13 @@ def _add_custom_model_args(parser) -> None: "--params-path", type=Path, default=None, - help= "Use the specified parameter file, instead of one specified under torchchat.model_params", + help="Use the specified parameter file, instead of one specified under torchchat.model_params", ) parser.add_argument( "--tokenizer-path", type=Path, default=None, - help= "Use the specified model tokenizer file, instead of the one downloaded from HuggingFace", + help="Use the specified model tokenizer file, instead of the one downloaded from HuggingFace", ) diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py index af4ca028a..39b677030 100644 --- a/torchchat/distributed/generate.py +++ b/torchchat/distributed/generate.py @@ -3,25 +3,25 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import asyncio +import atexit +import importlib.util +import subprocess +import threading from abc import abstractmethod from collections import deque from dataclasses import dataclass from functools import partial from os import environ from pathlib import Path -from torchchat.cli.builder import BuilderArgs, TokenizerArgs from typing import List, Optional from uuid import uuid4 -import asyncio -import atexit import torch.multiprocessing as mp -import threading -import importlib.util -import subprocess +from torchchat.cli.builder import BuilderArgs, TokenizerArgs -def _setup_env(world_size:int, rank:int, target: callable, *args, **kwargs): +def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): environ["MASTER_ADDR"] = "localhost" environ["MASTER_PORT"] = "29500" environ["RDZV_BACKEND"] = "c10d" @@ -36,10 +36,11 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None: # create programmatic elastic launch print("Launching distributed inference ...") - num_processes_per_node = 4 # builder_args.num_gpus + 1 + num_processes_per_node = builder_args.pp * builder_args.tp from torchchat.distributed.dist_run import main - mp.set_start_method('spawn') + + mp.set_start_method("spawn") pipes = [] procs = [] @@ -48,25 +49,24 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None: pipes.append(server_pipe) proc = mp.Process( target=partial(_setup_env, num_processes_per_node, rank, main), - args=(builder_args, client_pipe) + args=(builder_args, client_pipe), ) proc.start() - for pipe in pipes: response = pipe.recv() - print( - f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs." - ) + print(f"Done launching distributed inference on {num_processes_per_node} GPUs.") return procs, pipes + @dataclass class Output: is_finished: bool = False text: Optional[str] = None token: Optional[list] = None + @dataclass class Request: request_id: int @@ -84,7 +84,7 @@ def __init__( generator_args, pipes, loop, - ): + ): self.builder_args = builder_args self.generator_args = generator_args self.requests = {} @@ -107,7 +107,7 @@ def process_requests_loop(self): if req == "stop": break self.requests = {req.request_id: req.prompt} - + responses = {} running = True while running: @@ -128,17 +128,17 @@ async def wait_for_request(self, req: Request) -> Output: yield output del self.req_to_states[req.request_id] del self.req_to_results[req.request_id] - + def step(self) -> List[Output]: responses = [] - #TODO: Implement a scheduler to handle the requests + # TODO: Implement a scheduler to handle the requests if len(self.in_flight_requests) > 0: - #Receive decoded token + # Receive decoded token for p in self.pipes: p.send("step") for p in self.pipes: responses.append(p.recv()) - + else: # Send requests to backend self.in_flight_batch_order = list(self.requests.keys()) @@ -148,25 +148,26 @@ def step(self) -> List[Output]: self.in_flight_requests = self.requests self.requests = {} self.current_step = 0 - #Receive first token + # Receive first token for p in self.pipes: responses.append(p.recv()) - responses = responses[0] + # Filter out None responses from in-between stages + responses = [r for r in responses if r is not None][0] outputs = [] for k, v in zip(self.in_flight_batch_order, zip(responses[0], responses[1])): text, token_ids = v outputs.append( Output( - is_finished=self.current_step>=self.generator_args.max_new_tokens, + is_finished=self.current_step >= self.generator_args.max_new_tokens, text=text, token=token_ids, - ) ) + ) if 
self.current_step >= self.generator_args.max_new_tokens: for p in self.pipes: p.send("stop") self.in_flight_requests = [] - + self.current_step += 1 return outputs @@ -177,15 +178,17 @@ def __init__( self, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs, - #TODO: move GeneratorArgs into a different module + # TODO: move GeneratorArgs into a different module generator_args, profile: Optional[Path], quantize: bool, draft_quantize: bool, - ): + ): self.builder_args = builder_args self.generate_args = generator_args - + + self.check_args() + self.procs, self.pipes = _launch_distributed_inference(builder_args) self.loop = asyncio.new_event_loop() @@ -193,8 +196,10 @@ def __init__( self.scheduler = Scheduler(builder_args, generator_args, self.pipes, self.loop) - #TODO: Mode into process and use pipe or queue for comm - self.scheduler_thread = threading.Thread(target=self.scheduler.process_requests_loop) + # TODO: Mode into process and use pipe or queue for comm + self.scheduler_thread = threading.Thread( + target=self.scheduler.process_requests_loop + ) self.scheduler_thread.start() atexit.register(self.shutdown) @@ -220,3 +225,9 @@ def generate(self, text): running &= not output.is_finished yield output + + def check_args(self): + if self.generate_args.chat_mode: + raise NotImplementedError( + "Currently we only support generate with --distributed" + ) diff --git a/torchchat/generate.py b/torchchat/generate.py index ef31a359f..812763ad8 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1243,7 +1243,7 @@ def main(args): ) response = "" - for output in dist_gen.generate("Tell me a joke"): + for output in dist_gen.generate(generator_args.prompt): response += output.text print(f"Model output: {response}") From 3f6fa2d4f25e102f47704ea0034a1e28f3cdb42d Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:32:09 -0700 Subject: [PATCH 16/24] Enforce tp>=2 --- torchchat/cli/cli.py | 11 ++++------- torchchat/distributed/dist_run.py | 32 +++++++++++++++++++------------ torchchat/distributed/generate.py | 2 ++ 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 6f9f2261f..bc41d56ec 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -399,8 +399,7 @@ def _add_distributed_args(parser) -> None: parser.add_argument( "--distributed", action="store_true", - help=argparse.SUPPRESS, - # "Whether to enable distributed inference", + help="Whether to enable distributed inference", ) parser.add_argument( "--dcp-dir", @@ -414,16 +413,14 @@ def _add_distributed_args(parser) -> None: "--pipeline-parallel", type=int, default=1, - help=argparse.SUPPRESS, - # "Pipeline parallel degree", + help="Pipeline parallel degree", ) parser.add_argument( "--tp", "--tensor-parallel", type=int, - default=1, - help=argparse.SUPPRESS, - # "Tensor parallel degree", + default=2, + help="Tensor parallel degree", ) parser.add_argument( "--chpt-from", diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py index 21438264d..93313f3d2 100644 --- a/torchchat/distributed/dist_run.py +++ b/torchchat/distributed/dist_run.py @@ -12,7 +12,7 @@ import os from enum import auto, Enum from pathlib import Path -from types import SimpleNamespace, MethodType +from types import MethodType, SimpleNamespace from typing import Any, Dict, List, Optional, Tuple import torch @@ -71,21 +71,26 @@ def _init_distributed(): def _create_device_mesh(pp_degree, tp_degree): - return 
dist.init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")) + return dist.init_device_mesh( + "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp") + ) def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace: return SimpleNamespace(**dictionary) + def _patch_tokenizer(tokenizer): """Patch the tokenizer to support decoding of token ids.""" if isinstance(tokenizer, TiktokenTokenizer): # Patch tiktokenizer to allow a list of sequences. - #TODO: Upstream to tokenizer modules + # TODO: Upstream to tokenizer modules old_decode = tokenizer.decode - def decode(self, token_ids: List[int|List[int]], *args, **kwargs) -> str | List[str]: - if len(token_ids)<1: + def decode( + self, token_ids: List[int | List[int]], *args, **kwargs + ) -> str | List[str]: + if len(token_ids) < 1: return "" if isinstance(token_ids[0], list): return [old_decode(t, *args, **kwargs) for t in token_ids] @@ -95,6 +100,7 @@ def decode(self, token_ids: List[int|List[int]], *args, **kwargs) -> str | List[ tokenizer.decode = MethodType(decode, tokenizer) return tokenizer + def _build_chat_tokenizer( model_name: str, model_base_name: Optional[str] = None, @@ -221,7 +227,7 @@ def _create_padded_prompts( def _batch_decode_next_tokens( output: torch.Tensor, - pos: List[int]=None, + pos: List[int] = None, temperature: float = 1.0, topk: int = 10, ) -> torch.Tensor: @@ -388,7 +394,7 @@ def main(args, pipe): # Batch size. Since we push batches dynamically through the pipeline rather # than chunking them, this is effectively micro-batch size in pipeline # sense. Thus it is interchangeable with micro-batch size below. - batch_size = 1# len(prompt) + batch_size = 1 # len(prompt) seqlen_prefill = 1024 # sequence length dim = 4096 # embedding dimension @@ -463,7 +469,9 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: raise ValueError(f"Unknown command: {command}") else: prompt = command - assert len(prompt) == batch_size, f"Expecting {batch_size=} prompts but got {len(prompt)=}" + assert ( + len(prompt) == batch_size + ), f"Expecting {batch_size=} prompts but got {len(prompt)=}" logger.info(f"{color.green}Prompt: {prompt}{color.reset}") start_pos = 0 @@ -508,7 +516,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}") new_token = _batch_decode_next_tokens(output, prompt_lengths) res.append(new_token) - #TODO: Move to a separate decoding thread + # TODO: Move to a separate decoding thread resp = _decode_in_flight(new_token, tokenizer, tp_rank) pipe.send((resp, new_token.tolist())) else: @@ -539,7 +547,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: command = pipe.recv() assert isinstance(command, str) if command == "stop": - break + break elif command == "step": pass else: @@ -573,7 +581,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: if pp_rank == last_pp_rank: new_token = _batch_decode_next_tokens(output) res.append(new_token) - #TODO: Move to a separate decoding thread + # TODO: Move to a separate decoding thread resp = _decode_in_flight(new_token, tokenizer, tp_rank) pipe.send((resp, new_token)) else: @@ -602,7 +610,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: for prompt_text, response_text in zip(prompt, responses): logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}") logger.info(f"Response: {color.red}{response_text} {color.reset}") - + # Cleanup _cleanup() 
logger.info( diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py index 39b677030..94264ca99 100644 --- a/torchchat/distributed/generate.py +++ b/torchchat/distributed/generate.py @@ -231,3 +231,5 @@ def check_args(self): raise NotImplementedError( "Currently we only support generate with --distributed" ) + elif self.builder_args.tp < 2: + raise RuntimeError("TP degree must be at least 2 for distributed inference") From fd9f70421c47847a4cc984711339a644461d0dce Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:22:46 -0700 Subject: [PATCH 17/24] Build tokenizer from TokenizerArgs --- torchchat/distributed/dist_run.py | 40 ++++++++++--------------------- torchchat/distributed/generate.py | 10 +++++--- torchchat/generate.py | 5 +--- 3 files changed, 20 insertions(+), 35 deletions(-) diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py index 93313f3d2..276303c38 100644 --- a/torchchat/distributed/dist_run.py +++ b/torchchat/distributed/dist_run.py @@ -102,32 +102,11 @@ def decode( def _build_chat_tokenizer( - model_name: str, - model_base_name: Optional[str] = None, + tokenizer_args: TokenizerArgs, ) -> SentencePieceProcessor | TiktokenTokenizer: """Builds a tokenizer for the given model name""" - - # Try to infer the model base name from the model name: - # e.g. "llama2-7b-chat" -> "llama2" - if model_base_name is None: - model_base_name = model_name.split("-")[0] - logger.info( - f"Using model base name '{model_base_name}' to build tokenizer. " - "If not found, please specify it using the `model_base_name` argument." - ) - - # Create base args for tokenizer - default_model_dir = Path( - os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache") - ).expanduser() - - tokenconfig = { - "model_directory": default_model_dir, - "model": model_base_name, - "tokenizer_path": None, - } - args = dict_to_args(tokenconfig) - tokenizer_args = TokenizerArgs.from_args(args) + + tokenizer_args = TokenizerArgs.from_args(tokenizer_args) tokenizer = tokenizer_args.t assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}" logger.info( @@ -313,9 +292,14 @@ def _cleanup(): ] -def main(args, pipe): +def main( + builder_args, + tokenizer_args, + pipe, +): model_name = "llama3" # args.model_name - pp_degree = args.pp + # print(f"{builder_args.checkpoint_path=}") + pp_degree = builder_args.pp rank, world_size = _init_distributed() logger.info(f"Worker started: {rank=}, {world_size=}") @@ -332,7 +316,7 @@ def main(args, pipe): config = TransformerArgs.from_params(model_config.transformer_args["text"]) logger.info(f"Transformer Config: {config}") - tokenizer = _build_chat_tokenizer(model_name) + tokenizer = _build_chat_tokenizer(tokenizer_args) set_precision(model_dtype) logger.info(f"Using cache precision {model_dtype}") @@ -385,7 +369,7 @@ def main(args, pipe): # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") with CUDATrackTime() as timer: - _load_model_weights(model, distribution, device, config, args.chpt_from) + _load_model_weights(model, distribution, device, config, builder_args.chpt_from) logger.info( f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py index 94264ca99..16dc1b248 100644 --- a/torchchat/distributed/generate.py +++ b/torchchat/distributed/generate.py @@ -32,7 +32,9 @@ def _setup_env(world_size: 
int, rank: int, target: callable, *args, **kwargs): return target(*args, **kwargs) -def _launch_distributed_inference(builder_args: BuilderArgs) -> None: +def _launch_distributed_inference( + builder_args: BuilderArgs, tokenizer_args: TokenizerArgs +) -> tuple[List]: # create programmatic elastic launch print("Launching distributed inference ...") @@ -49,7 +51,7 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None: pipes.append(server_pipe) proc = mp.Process( target=partial(_setup_env, num_processes_per_node, rank, main), - args=(builder_args, client_pipe), + args=(builder_args, tokenizer_args, client_pipe), ) proc.start() @@ -189,7 +191,9 @@ def __init__( self.check_args() - self.procs, self.pipes = _launch_distributed_inference(builder_args) + self.procs, self.pipes = _launch_distributed_inference( + builder_args, tokenizer_args + ) self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) diff --git a/torchchat/generate.py b/torchchat/generate.py index 812763ad8..6e9e15585 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -30,8 +30,8 @@ BuilderArgs, TokenizerArgs, ) -from torchchat.model import Model, ModelType from torchchat.distributed.generate import DistributedGenerator +from torchchat.model import Model, ModelType from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info @@ -1228,7 +1228,6 @@ def main(args): ) if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - for _ in gen.chat(generator_args): pass @@ -1248,5 +1247,3 @@ def main(args): print(f"Model output: {response}") dist_gen.shutdown() - - From e8f7c987ffa380e7307ead95b2d826fe25759b12 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 12:35:49 -0700 Subject: [PATCH 18/24] Disable torchchat format + constrain possible models for distributed --- torchchat/distributed/dist_run.py | 5 ++--- torchchat/distributed/generate.py | 20 ++++++++++++++++---- torchchat/generate.py | 1 + 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py index 276303c38..ddef3efe3 100644 --- a/torchchat/distributed/dist_run.py +++ b/torchchat/distributed/dist_run.py @@ -105,7 +105,7 @@ def _build_chat_tokenizer( tokenizer_args: TokenizerArgs, ) -> SentencePieceProcessor | TiktokenTokenizer: """Builds a tokenizer for the given model name""" - + tokenizer_args = TokenizerArgs.from_args(tokenizer_args) tokenizer = tokenizer_args.t assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}" @@ -293,12 +293,11 @@ def _cleanup(): def main( + model_name, builder_args, tokenizer_args, pipe, ): - model_name = "llama3" # args.model_name - # print(f"{builder_args.checkpoint_path=}") pp_degree = builder_args.pp rank, world_size = _init_distributed() diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py index 16dc1b248..0abaebf29 100644 --- a/torchchat/distributed/generate.py +++ b/torchchat/distributed/generate.py @@ -19,6 +19,7 @@ import torch.multiprocessing as mp from torchchat.cli.builder import BuilderArgs, TokenizerArgs +from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): @@ -33,7 +34,7 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): def _launch_distributed_inference( - builder_args: BuilderArgs, tokenizer_args: 
TokenizerArgs + model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs ) -> tuple[List]: # create programmatic elastic launch print("Launching distributed inference ...") @@ -51,7 +52,7 @@ def _launch_distributed_inference( pipes.append(server_pipe) proc = mp.Process( target=partial(_setup_env, num_processes_per_node, rank, main), - args=(builder_args, tokenizer_args, client_pipe), + args=(model_name, builder_args, tokenizer_args, client_pipe), ) proc.start() @@ -178,6 +179,8 @@ def step(self) -> List[Output]: class DistributedGenerator(object): def __init__( self, + # TODO: switch this to torchchat method + model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs, # TODO: move GeneratorArgs into a different module @@ -186,13 +189,14 @@ def __init__( quantize: bool, draft_quantize: bool, ): + self.model_name = model_name self.builder_args = builder_args self.generate_args = generator_args self.check_args() self.procs, self.pipes = _launch_distributed_inference( - builder_args, tokenizer_args + model_name, builder_args, tokenizer_args ) self.loop = asyncio.new_event_loop() @@ -236,4 +240,12 @@ def check_args(self): "Currently we only support generate with --distributed" ) elif self.builder_args.tp < 2: - raise RuntimeError("TP degree must be at least 2 for distributed inference") + raise ValueError("TP degree must be at least 2 for distributed inference") + elif self.model_name not in NAME_TO_DISTRIBUTION_AND_DTYPE.keys(): + raise ValueError( + f"Distributed inference currently only supports then following models: {list(NAME_TO_DISTRIBUTION_AND_DTYPE.keys())}" + ) + elif self.builder_args.chpt_from == "torchchat": + raise ValueError( + f"Distributed inference currently only supports HF checkpoints" + ) diff --git a/torchchat/generate.py b/torchchat/generate.py index 6e9e15585..9f051bf5a 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1233,6 +1233,7 @@ def main(args): pass else: dist_gen = DistributedGenerator( + args.model, builder_args, tokenizer_args, generator_args, From 9ec55fbae09c2349bf78a23aa5100e8f2640b397 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 12:42:12 -0700 Subject: [PATCH 19/24] disable calling dist_run.py directly for now --- torchchat/distributed/dist_run.py | 57 +++++++++++++++---------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py index ddef3efe3..389ae41c1 100644 --- a/torchchat/distributed/dist_run.py +++ b/torchchat/distributed/dist_run.py @@ -287,7 +287,7 @@ def _cleanup(): # "Can you explain what is the purpose of back propagation in neural networks?", "Who is Santa Claus?", "Where does Santa live?", - # "Who is Abraham Lincoln?", + "Who is Abraham Lincoln?", # "How are models trained?", ] @@ -600,31 +600,30 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}" ) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - """parser.add_argument( - "model_name", - type=str, - default="llama3", - help="Name of the model to load", - # choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), - ) - """ - parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") - parser.add_argument( - "--ntokens", - type=int, - default=40, - help="Number of tokens to generate", - ) - parser.add_argument( - "--chpt-from", - type=str, - 
default="hf", # TODO: change to torchchat once we support it well - help="Checkpoint format to load from", - choices=["hf", "torchchat"], - ) - args = parser.parse_args() - - main(args) +# TODO: remove or make it work again +# if __name__ == "__main__": +# parser = argparse.ArgumentParser() +# parser.add_argument( +# "model_name", +# type=str, +# default="llama3", +# help="Name of the model to load", +# choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), +# ) +# parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") +# parser.add_argument( +# "--ntokens", +# type=int, +# default=40, +# help="Number of tokens to generate", +# ) +# parser.add_argument( +# "--chpt-from", +# type=str, +# default="hf", # TODO: change to torchchat once we support it well +# help="Checkpoint format to load from", +# choices=["hf", "torchchat"], +# ) +# args = parser.parse_args() + +# main() From 80f8138b9c955918ff64a977f3832abd73f794b1 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 12:44:51 -0700 Subject: [PATCH 20/24] Restore original dist_run.py for now --- dist_run.py | 625 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 625 insertions(+) create mode 100644 dist_run.py diff --git a/dist_run.py b/dist_run.py new file mode 100644 index 000000000..72fc5855a --- /dev/null +++ b/dist_run.py @@ -0,0 +1,625 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Example run command: +# torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --pp 2 +# torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2 + +import argparse +import os +from enum import auto, Enum +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch.distributed.pipelining import PipelineStage, ScheduleGPipe +from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs + +# TODO - these are not distributed specific, consider moving to new package +from torchchat.distributed.checkpoint_utils import ( + get_hf_config_file, + load_weights_from_hf_format, + load_weights_from_torchchat_format, +) + +from torchchat.distributed.logging_utils import SingletonLogger +from torchchat.distributed.utils import ( + bytes_to_readable, + Color as color, + CUDATrackTime, + get_module_size, + get_num_params, + GPUMemoryMonitor, +) +from torchchat.model import ModelArgs, Transformer, TransformerArgs +from torchchat.utils.build_utils import set_precision + +try: + from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer +except ImportError: + TiktokenTokenizer = None +try: + from sentencepiece import SentencePieceProcessor +except ImportError: + SentencePieceProcessor = None + + +logger = SingletonLogger.get_logger() +_tokenizer_type = None # global variable to store the tokenizer type + +# Using model name to identify the model to load, for example "llama2-7b-chat". +# You can change it to other values listed below. +# For details on the name-to-distribution mapping, see README.md or models.json. 
+NAME_TO_DISTRIBUTION_AND_DTYPE = { + "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16), + "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16), +} + + +class TokenizerType(Enum): + Tiktoken = auto() + SentencePiece = auto() + + +def _init_distributed(): + dist.init_process_group("nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + # Assuming same number of GPUs per node + torch.cuda.set_device(rank % torch.cuda.device_count()) + return rank, world_size + + +def _create_device_mesh(mesh_dimensions): + return dist.init_device_mesh("cuda", mesh_dimensions, mesh_dim_names=("pp", "tp")) + + +def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace: + return SimpleNamespace(**dictionary) + + +def _build_chat_tokenizer( + model_name: str, + model_base_name: Optional[str] = None, +) -> SentencePieceProcessor | TiktokenTokenizer: + """Builds a tokenizer for the given model name, and sets the global tokenizer type variable""" + + global _tokenizer_type + + # Try to infer the model base name from the model name: + # e.g. "llama2-7b-chat" -> "llama2" + if model_base_name is None: + model_base_name = model_name.split("-")[0] + logger.info( + f"Using model base name '{model_base_name}' to build tokenizer. " + "If not found, please specify it using the `model_base_name` argument." + ) + + # Create base args for tokenizer + default_model_dir = Path( + os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache") + ).expanduser() + + tokenconfig = { + "model_directory": default_model_dir, + "model": model_base_name, + "tokenizer_path": None, + } + args = dict_to_args(tokenconfig) + tokenizer_args = TokenizerArgs.from_args(args) + tokenizer = _initialize_tokenizer(tokenizer_args) + assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}" + logger.info( + f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}" + ) + # set global variable _tokenizer_type + if isinstance(tokenizer, TiktokenTokenizer): + _tokenizer_type = TokenizerType.Tiktoken + elif isinstance(tokenizer, SentencePieceProcessor): + _tokenizer_type = TokenizerType.SentencePiece + else: + raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}") + + logger.info(f"tokenizer type = {_tokenizer_type}") + return tokenizer + + +def _load_model_weights( + stage_module: torch.nn.Module, + distribution: str, + device: torch.device, + model_config: ModelArgs, + chpt_from: str, +): + """Load the weights from the safetensor file(s) into the model stage. + Model config is needed b/c we permute wq and wk weights based on attn heads. + + Args: + stage_module (torch.nn.Module): The model stage to load the weights into. + distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct". + device (torch.device): The device to load the weights onto. + model_config (ModelArgs): The model config. + chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf". + """ + if chpt_from == "hf": + # This format stands for: index file + multiple binary files + load_weights_from_hf_format(stage_module, distribution, device, model_config) + elif chpt_from == "torchchat": + # This format stands for: + # single binary file, OR + # multiple binary files without index files. 
+ load_weights_from_torchchat_format( + stage_module, distribution, device, model_config + ) + else: + raise ValueError(f"Unknown checkpoint format: {chpt_from}") + + +def _encode_strings( + strings: List[str], + tokenizer, + bos: bool, + device: torch.device, + dtype=torch.int64, +) -> List[torch.Tensor]: + """Encode a list of prompt strings into a list of tensor token ids.""" + encoded_list = [] + for string in strings: + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device)) + return encoded_list + + +def _create_padded_prompts( + input_ids_list: List[torch.Tensor], + tokenizer, + seqlen: int, + start_pos: int, + device: torch.device, + pad_token_id: Optional[int] = None, +) -> Tuple[torch.Tensor, List[int]]: + """ + Create a padded tensor for multiple encoded input prompts. + + Returns: + Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths. + """ + pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id() + + # Find the maximum prompt length + max_prompt_len = max(ids.size(0) for ids in input_ids_list) + + # Calculate the buffer size + max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len)) + token_buffer_size = max_prompt_len + max_new_tokens + + # Create the padded batch tensor + batch_size = len(input_ids_list) + batch_seq = torch.full( + (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device + ) + + prompt_lengths = [] + for i, input_ids in enumerate(input_ids_list): + prompt_len = input_ids.size(0) + batch_seq[i, :prompt_len] = input_ids + prompt_lengths.append(prompt_len) + + return batch_seq, prompt_lengths + + +def _batch_decode_next_tokens( + output: torch.Tensor, + pos: List[int], + step: int = -1, + temperature: float = 1.0, + topk: int = 10, +) -> torch.Tensor: + """ + Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding. + + Args: + output (torch.Tensor): The output tensor to decode. + pos (List[int]): The positions of the `output` to decode in the sequence length dimension. + step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token. + temperature (float): Sampling temperature for non-deterministic decoding. + + Returns: + torch.Tensor: Decoded token ids. 
+ """ + batch_size, seq_len, vocab_size = output.shape + + if step != -1: + # `pos` is not provided, so we can use the first token + next_token_logits = output[:, 0, :] + else: + # get the logits for each prompt at the specified positions + next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1] + + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + + # Uses top-k sampling if temperature is not 1.0, otherwise use argmax + if temperature != 1.0: + top_k = min(topk, vocab_size) # Ensure top-k is not greater than vocab size + top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1) + probs = torch.softmax(top_k_logits, dim=-1) + next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1) + next_tokens = top_k_indices.gather( + -1, next_token_indices.unsqueeze(-1) + ).squeeze(-1) + else: + # Argmax (deterministic) + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + # Token ids in int tensor form + return next_tokens + + +def _update_padded_sequence( + padded_sequence: torch.Tensor, + new_token: torch.Tensor, + prompt_lengths: List[int], +) -> None: + for i in range(len(prompt_lengths)): + padded_sequence[i, prompt_lengths[i]] = new_token[i, 0] + # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") + + +# Decode token id into string and print it +def _decode_in_flight(token, tokenizer, tp_rank): + """decode token ids for all prompts in the batch and log them""" + # `token` is a tensor of shape (batch_size, 1). + # For TiktokenTokenizer, we need to squeeze it to 1D. + # For SentencePieceProcessor, we don't. + if isinstance(tokenizer, TiktokenTokenizer): + token = torch.squeeze(token, dim=1) + token_str = tokenizer.decode(token.tolist()) + # print the token string on tp rank 0 + if tp_rank == 0: + logger.info( + f"{color.green} responses ====>>>> " + f"{color.blue} {token_str} {color.reset}" + ) + + +def _cleanup(): + dist.barrier() + dist.destroy_process_group() + + +prompt = [ + "What is Snow?", + # "Can you explain what is the purpose of back propagation in neural networks?", + "Who is Santa Claus?", + "Where does Santa live?", + # "Who is Abraham Lincoln?", + # "How are models trained?", +] + + +def main(args): + model_name = "llama3" # args.model_name + pp_degree = args.pp + + rank, world_size = _init_distributed() + + gpu_memory_monitor = GPUMemoryMonitor("cuda") + logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") + + distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name] + logger.info(f"Using model weights from {distribution} and dtype {model_dtype}") + + # Model-level config + model_config = ModelArgs.from_name(distribution) + # Transformer-level config + config = TransformerArgs.from_params(model_config.transformer_args["text"]) + logger.info(f"Transformer Config: {config}") + + tokenizer = _build_chat_tokenizer(model_name) + + set_precision(model_dtype) + logger.info(f"Using cache precision {model_dtype}") + + hf_config = get_hf_config_file(distribution) + if hf_config is None: + raise ValueError(f"Config file not found for model id {distribution}") + + # Validate pipeline degree + assert world_size % pp_degree == 0 + assert config.n_layers % pp_degree == 0 + + # Tensor parallel is enabled in this program + tp_degree = world_size // pp_degree + + # Create device mesh + mesh_dimensions = (pp_degree, tp_degree) + device_mesh = _create_device_mesh(mesh_dimensions) + tp_mesh = device_mesh["tp"] + pp_mesh = device_mesh["pp"] 
+ logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") + + tp_rank = tp_mesh.get_local_rank() + pp_rank = pp_mesh.get_local_rank() + tp_group = tp_mesh.get_group() + pp_group = pp_mesh.get_group() + logger.info(f"{pp_degree=}, {tp_degree=}") + + # Convenience variables + first_pp_rank = 0 + last_pp_rank = pp_degree - 1 + + # Assuming same number of GPUs per node + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + + # Fill in PP configs + config.stage_idx = pp_rank + config.n_stages = pp_degree + + with torch.device("meta"): + # TODO: we should create model instead of Transformer + model = Transformer(config) + + # Distribute model on TP mesh + # (Surprisingly, this works even though model is on meta device and mesh is of + # cuda devices) + model.distribute(tp_mesh) + if rank == 0: + logger.info(f"Model: {model}") + + # Load weights + logger.info(f"Loading weights for {pp_rank=} on {device=}") + with CUDATrackTime() as timer: + _load_model_weights(model, distribution, device, config, args.chpt_from) + + logger.info( + f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + + # Batch size. Since we push batches dynamically through the pipeline rather + # than chunking them, this is effectively micro-batch size in pipeline + # sense. Thus it is interchangeable with micro-batch size below. + batch_size = len(prompt) + seqlen_prefill = 1024 # sequence length + dim = 4096 # embedding dimension + + # Setup KV caches (after model distribution) + # The number of cache lanes is the same as the maximum number of + # micro-batches that can be "in flight" in parallel -- imagine each + # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces. + # When decoding is done for certain micro-batches, we can reuse the KV cache + # lanes. + # TODO: bump up the lane count + pipeline_lanes = 1 + with device: + model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes) + + # info on stage size and params + stage_size = get_module_size(model) + stage_size_formatted = bytes_to_readable(stage_size) + stage_num_params = get_num_params(model) + logger.info( + f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}" + ) + + # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen + input_pos = torch.arange(seqlen_prefill, device=device) + model.eval() + + # Helper function to get example inputs and outputs for the stages. 
+ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: + mb_ids = torch.randint( + 0, config.vocab_size, (batch_size, seqlen), device=device + ) + activation = torch.rand( + batch_size, seqlen, dim, device=device, dtype=model_dtype + ) + logits = torch.rand( + batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype + ) + example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,) + example_outputs = (logits if pp_rank == last_pp_rank else activation,) + return example_inputs, example_outputs + + # Create prefill stage + logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}") + example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill) + prefill_stage = PipelineStage( + model, + pp_rank, + pp_degree, + device, + input_args=example_inputs, + output_args=example_outputs, + group=pp_group, + ) + + # Create schedule + # Number of micro-batches for the schedule is 1, because each step() call we + # only push 1 micro-batch into the pipeline. But we can continuously push + # new micro-batches into the pipeline as they arrive, achieving same + # pipelining effect. + prefiller = ScheduleGPipe(prefill_stage, 1) + + start_pos = 0 + + # Need these global ids due to the API definition of dist.send and recv + first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) + last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) + + # encode the prompt + input_ids = _encode_strings( + prompt, tokenizer, bos=True, device=device, dtype=torch.int64 + ) + + # create a padded tensor for the input prompt + padded_sequence, prompt_lengths = _create_padded_prompts( + input_ids, tokenizer, seqlen_prefill, start_pos, device + ) + + # Need these global ids due to the API definition of dist.send and recv + first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) + last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) + + # New token generated each iteration + # need a row dimension for each prompt in the batch + new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64) + # Store the generated tokens + res = [] + + # Prefill phase + # Run context input through pipeline + # TODO: we need to pass `input_pos` and `cache_lane` to each stage. 
+ lane = 0 + kwargs = {"input_pos": input_pos, "cache_lane": lane} + with torch.no_grad(), CUDATrackTime() as timer: + if pp_rank == first_pp_rank: + output = prefiller.step(padded_sequence, **kwargs) + elif pp_rank == last_pp_rank: + output = prefiller.step(**kwargs) + else: # middle pp ranks + prefiller.step(**kwargs) + + logger.info( + f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + + # Decode the output -- first generated token + if pp_rank == last_pp_rank: + logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}") + new_token = _batch_decode_next_tokens(output, prompt_lengths) + res.append(new_token) + if not args.disable_in_flight_decode: + _decode_in_flight(new_token, tokenizer, tp_rank) + + # seqlen = 1 now + seqlen_decode = 1 + input_pos = torch.tensor([prompt_lengths[0]], device=device) + + # Create decode stage + logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}") + example_inputs, example_outputs = get_example_ins_outs(seqlen_decode) + decode_stage = PipelineStage( + model, + pp_rank, + pp_degree, + device, + input_args=example_inputs, + output_args=example_outputs, + group=pp_group, + ) + # create schedule + decoder = ScheduleGPipe(decode_stage, 1) + + # Decoding + with torch.no_grad(), CUDATrackTime() as timer: + for step in range(args.ntokens - 1): + kwargs = {"input_pos": input_pos, "cache_lane": lane} + # sendrecv between last and first ranks, only if: + # first_pp_rank != last_pp_rank. + if pp_rank == last_pp_rank and pp_rank != first_pp_rank: + dist.send( + new_token, + dst=first_pp_rank_global_id, + group=pp_group, + ) + elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: + dist.recv( + new_token, + src=last_pp_rank_global_id, + group=pp_group, + ) + + # Run data through pipeline + if pp_rank == first_pp_rank: + output = decoder.step(new_token, **kwargs) + elif pp_rank == last_pp_rank: + output = decoder.step(**kwargs) + else: # middle pp ranks + decoder.step(**kwargs) + + # Decode the output + if pp_rank == last_pp_rank: + new_token = _batch_decode_next_tokens(output, prompt_lengths, step) + res.append(new_token) + if not args.disable_in_flight_decode: + _decode_in_flight(new_token, tokenizer, tp_rank) + + # Increment input position + input_pos += 1 + + logger.info( + f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + + # Display the decoding results + + # output formatted response via last pp group and tp rank 0 + if pp_rank == last_pp_rank and tp_rank == 0: + # `res` is a list of tensors, each being a batch of generated token ids. + # We need to concatenate them to get the full sequence of generated + # token ids. Thus cat'ing along dim 1. + res = torch.cat(res, dim=1) + res_list = res.tolist() + if _tokenizer_type == TokenizerType.Tiktoken: + # For TiktokenTokenizer, we need to decode prompt by prompt. + # TODO: is there a better way to do this? + responses = [tokenizer.decode(sequence) for sequence in res_list] + elif _tokenizer_type == TokenizerType.SentencePiece: # SentencePieceProcessor + # For SentencePieceProcessor, we can decode the entire 2D list at once. 
+ responses = tokenizer.decode(res_list) + else: + raise ValueError(f"Unknown tokenizer type {_tokenizer_type}") + + # Show prompts and responses + for prompt_text, response_text in zip(prompt, responses): + logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}") + logger.info(f"Response: {color.red}{response_text} {color.reset}") + + # Cleanup + _cleanup() + logger.info( + f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + """parser.add_argument( + "model_name", + type=str, + default="llama3", + help="Name of the model to load", + # choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), + ) + """ + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") + parser.add_argument( + "--ntokens", + type=int, + default=40, + help="Number of tokens to generate", + ) + parser.add_argument( + "--disable-in-flight-decode", + action="store_true", + default=False, + help="Whether to decode token into string in flight", + ) + parser.add_argument( + "--chpt-from", + type=str, + default="hf", # TODO: change to torchchat once we support it well + help="Checkpoint format to load from", + choices=["hf", "torchchat"], + ) + args = parser.parse_args() + + main(args) From 99606ab348ff9d3f518b82c14b5eac33262d1936 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 12:59:25 -0700 Subject: [PATCH 21/24] disable _maybe_parallelize_model again --- torchchat/cli/builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 4db385504..511cf1f35 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -92,7 +92,9 @@ def __post_init__(self): ] for param, param_msg in ignored_params: if param: - print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified") + print( + f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified" + ) else: self.prefill_possible = True @@ -495,7 +497,7 @@ def _load_model(builder_args: BuilderArgs) -> Model: # model = _init_model_on_meta_device(builder_args) else: model = _load_model_default(builder_args) - model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) + # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() From 4b8cdcb5e02293ef7a422ec1e2c44938fa0995a6 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:21:27 -0700 Subject: [PATCH 22/24] Reenable arg.model_name in dist_run.py --- dist_run.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dist_run.py b/dist_run.py index 72fc5855a..30bf92669 100644 --- a/dist_run.py +++ b/dist_run.py @@ -306,7 +306,7 @@ def _cleanup(): def main(args): - model_name = "llama3" # args.model_name + model_name = args.model_name pp_degree = args.pp rank, world_size = _init_distributed() @@ -592,14 +592,14 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: if __name__ == "__main__": parser = argparse.ArgumentParser() - """parser.add_argument( + parser.add_argument( "model_name", type=str, default="llama3", help="Name of the model to load", - # choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), + choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), ) - """ + 
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") parser.add_argument( "--ntokens", From b8f88fdd31b60a176539ca4f6b70d515fa92afaf Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:29:17 -0700 Subject: [PATCH 23/24] Use singleton logger instead of print in generate --- torchchat/distributed/generate.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py index 0abaebf29..6b36f8ca0 100644 --- a/torchchat/distributed/generate.py +++ b/torchchat/distributed/generate.py @@ -20,6 +20,9 @@ import torch.multiprocessing as mp from torchchat.cli.builder import BuilderArgs, TokenizerArgs from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE +from torchchat.distributed.logging_utils import SingletonLogger + +logger = SingletonLogger.get_logger() def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): @@ -37,7 +40,7 @@ def _launch_distributed_inference( model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs ) -> tuple[List]: # create programmatic elastic launch - print("Launching distributed inference ...") + logger.info("Launching distributed inference ...") num_processes_per_node = builder_args.pp * builder_args.tp @@ -59,7 +62,9 @@ def _launch_distributed_inference( for pipe in pipes: response = pipe.recv() - print(f"Done launching distributed inference on {num_processes_per_node} GPUs.") + logger.info( + f"Done launching distributed inference on {num_processes_per_node} GPUs." + ) return procs, pipes From 2d37d27d2a3e321534142346760547a5978db7e4 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:04:08 -0700 Subject: [PATCH 24/24] Address PR comments; try/expect in launch_dist_inference; added comments --- torchchat/distributed/generate.py | 37 ++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py index 6b36f8ca0..51c472e4a 100644 --- a/torchchat/distributed/generate.py +++ b/torchchat/distributed/generate.py @@ -39,7 +39,7 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): def _launch_distributed_inference( model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs ) -> tuple[List]: - # create programmatic elastic launch + # launch distributed inference worker, each worker gets a pipe to communicate with the main process logger.info("Launching distributed inference ...") num_processes_per_node = builder_args.pp * builder_args.tp @@ -50,17 +50,25 @@ def _launch_distributed_inference( pipes = [] procs = [] - for rank in range(num_processes_per_node): - server_pipe, client_pipe = mp.Pipe(duplex=True) - pipes.append(server_pipe) - proc = mp.Process( - target=partial(_setup_env, num_processes_per_node, rank, main), - args=(model_name, builder_args, tokenizer_args, client_pipe), - ) - proc.start() + try: + for rank in range(num_processes_per_node): + server_pipe, client_pipe = mp.Pipe(duplex=True) + pipes.append(server_pipe) + procs.append( + mp.Process( + target=partial(_setup_env, num_processes_per_node, rank, main), + args=(model_name, builder_args, tokenizer_args, client_pipe), + ) + ) + procs[-1].start() - for pipe in pipes: - response = pipe.recv() + for pipe in pipes: + assert pipe.recv() == "ready", "Starting the worker failed" + except 
Exception as e: + logger.error(f"Error during distributed inference: {str(e)}") + for p in procs: + p.kill() + raise e logger.info( f"Done launching distributed inference on {num_processes_per_node} GPUs." @@ -105,11 +113,13 @@ def __init__( self.loop = loop def schedule_request(self, req: Request): + # add request to queue and create deque and async event for response self.req_to_states[req.request_id] = asyncio.Event() self.req_to_results[req.request_id] = deque() self.request_queue.put(req) def process_requests_loop(self): + # Continuously process requests (one at a time for now), results are routed into the requests deque while True: req = self.request_queue.get() if req == "stop": @@ -127,6 +137,7 @@ def process_requests_loop(self): running &= not outputs[0].is_finished async def wait_for_request(self, req: Request) -> Output: + # Wait for request to deliver result, uses event to trigger and reads from left side of deque is_finished = False while not is_finished: await self.req_to_states[req.request_id].wait() @@ -138,6 +149,7 @@ async def wait_for_request(self, req: Request) -> Output: del self.req_to_results[req.request_id] def step(self) -> List[Output]: + # Make a prefill or decoding step and receive results responses = [] # TODO: Implement a scheduler to handle the requests if len(self.in_flight_requests) > 0: @@ -166,6 +178,7 @@ def step(self) -> List[Output]: text, token_ids = v outputs.append( Output( + # TODO: Look for tokenizer.eos_id as well is_finished=self.current_step >= self.generator_args.max_new_tokens, text=text, token=token_ids, @@ -218,6 +231,7 @@ def __init__( atexit.register(self.shutdown) def shutdown(self): + # Stop all processes and threads self.scheduler.request_queue.put("stop") self.scheduler_thread.join() @@ -227,6 +241,7 @@ def shutdown(self): p.kill() def generate(self, text): + # Function to generate text from prompt req = Request.new_request(text) self.scheduler.schedule_request(req)
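
A minimal, self-contained sketch of the worker-launch handshake used in _launch_distributed_inference above (not part of the patch series; the names worker and launch are illustrative only). The real code wires _setup_env and main together with functools.partial and passes the model name, builder args, tokenizer args, and the client end of the pipe as process arguments, but the pattern is the same: spawn one process per rank, hand each a duplex pipe, block until every worker reports "ready", and kill whatever already started if any launch step fails.

import multiprocessing as mp


def worker(rank: int, conn) -> None:
    # Hypothetical stand-in for _setup_env/main: do the per-rank setup here,
    # then tell the parent process that this worker is ready.
    conn.send("ready")
    conn.close()


def launch(num_workers: int = 2):
    pipes, procs = [], []
    try:
        for rank in range(num_workers):
            parent_conn, child_conn = mp.Pipe(duplex=True)
            pipes.append(parent_conn)
            procs.append(mp.Process(target=worker, args=(rank, child_conn)))
            procs[-1].start()

        # Block until every worker has finished its setup and checked in.
        for pipe in pipes:
            assert pipe.recv() == "ready", "Starting the worker failed"
    except Exception:
        # Tear down any workers that already started before re-raising.
        for p in procs:
            p.kill()
        raise
    return procs, pipes


if __name__ == "__main__":
    procs, pipes = launch()
    for p in procs:
        p.join()

Because the pipes are duplex and are returned alongside the processes, the same connections can presumably carry later traffic between the main process and the workers, which is what the request/response plumbing in the Scheduler above relies on.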