diff --git a/dist_run.py b/dist_run.py index f8597c563..09e0be725 100644 --- a/dist_run.py +++ b/dist_run.py @@ -4,26 +4,28 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +# Example run command: +# torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --pp 2 +# torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2 + import argparse import os from pathlib import Path from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple -# Run command: -# torchrun --nproc-per-node 4 dist_run.py import torch import torch.distributed as dist -from distributed.logging_utils import SingletonLogger +from torchchat.distributed.logging_utils import SingletonLogger # TODO - these are not distributed specific, consider moving to new package -from distributed.safetensor_utils import ( +from torchchat.distributed.safetensor_utils import ( get_hf_config_file, get_hf_weight_map_and_path, load_safetensor_weights, ) -from distributed.utils import ( +from torchchat.distributed.utils import ( bytes_to_readable, Color as color, CUDATrackTime, diff --git a/distributed/__init__.py b/distributed/__init__.py deleted file mode 100644 index 894e96fff..000000000 --- a/distributed/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from distributed.checkpoint import load_checkpoints_to_model -from distributed.logging_utils import SingletonLogger -from distributed.parallel_config import ParallelDims -from distributed.parallelize_llama import parallelize_llama -from distributed.utils import init_distributed -from distributed.world_maker import launch_distributed diff --git a/run_dist.sh b/run_dist.sh deleted file mode 100644 index e6e3bb133..000000000 --- a/run_dist.sh +++ /dev/null @@ -1,7 +0,0 @@ -#export CUDA_VISIBLE_DEVICES=4,5,6,7 -PORT=${1:-29501} -NGPU=${NGPU:-"4"} -LOG_RANK=${LOG_RANK:-0,1,2,3} -torchrun --nproc-per-node=$NGPU --master_port=$PORT \ ---local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ -dist_run.py --pp 2 llama3 diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 1049b346f..0f70f77cd 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -21,7 +21,7 @@ except ImportError: pass -from distributed import launch_distributed, ParallelDims, parallelize_llama +from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama from torch.distributed.device_mesh import DeviceMesh diff --git a/distributed/README.md b/torchchat/distributed/README.md similarity index 100% rename from distributed/README.md rename to torchchat/distributed/README.md diff --git a/torchchat/distributed/__init__.py b/torchchat/distributed/__init__.py new file mode 100644 index 000000000..db96e3909 --- /dev/null +++ b/torchchat/distributed/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchchat.distributed.checkpoint import load_checkpoints_to_model +from torchchat.distributed.logging_utils import SingletonLogger +from torchchat.distributed.parallel_config import ParallelDims +from torchchat.distributed.parallelize_llama import parallelize_llama +from torchchat.distributed.utils import init_distributed +from torchchat.distributed.world_maker import launch_distributed diff --git a/distributed/checkpoint.py b/torchchat/distributed/checkpoint.py similarity index 100% rename from distributed/checkpoint.py rename to torchchat/distributed/checkpoint.py diff --git a/distributed/config_manager.py b/torchchat/distributed/config_manager.py similarity index 98% rename from distributed/config_manager.py rename to torchchat/distributed/config_manager.py index fa0f5ee4c..db71fb5f6 100644 --- a/distributed/config_manager.py +++ b/torchchat/distributed/config_manager.py @@ -12,7 +12,7 @@ import torch -from distributed.logging_utils import SingletonLogger +from torchchat.distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() diff --git a/distributed/dtensor_utils.py b/torchchat/distributed/dtensor_utils.py similarity index 97% rename from distributed/dtensor_utils.py rename to torchchat/distributed/dtensor_utils.py index 6ce2c3fb4..9e57da428 100644 --- a/distributed/dtensor_utils.py +++ b/torchchat/distributed/dtensor_utils.py @@ -4,7 +4,7 @@ from collections import defaultdict -from distributed.logging_utils import SingletonLogger +from torchchat.distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() diff --git a/distributed/force_download.py b/torchchat/distributed/force_download.py similarity index 100% rename from distributed/force_download.py rename to torchchat/distributed/force_download.py diff --git a/distributed/inference_configs/llama3_8B.toml b/torchchat/distributed/inference_configs/llama3_8B.toml similarity index 100% rename from distributed/inference_configs/llama3_8B.toml rename to torchchat/distributed/inference_configs/llama3_8B.toml diff --git a/distributed/logging_utils.py b/torchchat/distributed/logging_utils.py similarity index 100% rename from distributed/logging_utils.py rename to torchchat/distributed/logging_utils.py diff --git a/distributed/parallel_config.py b/torchchat/distributed/parallel_config.py similarity index 95% rename from distributed/parallel_config.py rename to torchchat/distributed/parallel_config.py index 71f568332..cc439b42e 100644 --- a/distributed/parallel_config.py +++ b/torchchat/distributed/parallel_config.py @@ -8,7 +8,7 @@ from torch.distributed.device_mesh import init_device_mesh -from distributed.logging_utils import SingletonLogger +from torchchat.distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() @dataclass diff --git a/distributed/parallelize_llama.py b/torchchat/distributed/parallelize_llama.py similarity index 97% rename from distributed/parallelize_llama.py rename to torchchat/distributed/parallelize_llama.py index 0b1dca4cf..a965dd264 100644 --- a/distributed/parallelize_llama.py +++ b/torchchat/distributed/parallelize_llama.py @@ -11,9 +11,9 @@ parallelize_module) -from distributed.parallel_config import ParallelDims +from torchchat.distributed.parallel_config import ParallelDims -from distributed.logging_utils import SingletonLogger +from torchchat.distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() diff --git a/distributed/run_dist_inference.sh b/torchchat/distributed/run_dist_inference.sh similarity index 100% rename from distributed/run_dist_inference.sh rename to torchchat/distributed/run_dist_inference.sh diff --git a/distributed/safetensor_utils.py b/torchchat/distributed/safetensor_utils.py similarity index 98% rename from distributed/safetensor_utils.py rename to torchchat/distributed/safetensor_utils.py index d5bab6c1f..39eaee71b 100644 --- a/distributed/safetensor_utils.py +++ b/torchchat/distributed/safetensor_utils.py @@ -14,14 +14,14 @@ from typing import Dict, Tuple, Set, Optional -from distributed.dtensor_utils import is_dtensor, load_into_dtensor +from torchchat.distributed.dtensor_utils import is_dtensor, load_into_dtensor _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json" _CONFIG_NAME = "config.json" -from distributed.logging_utils import SingletonLogger +from torchchat.distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() diff --git a/distributed/utils.py b/torchchat/distributed/utils.py similarity index 99% rename from distributed/utils.py rename to torchchat/distributed/utils.py index bb2d2f23d..46ea5d9a1 100644 --- a/distributed/utils.py +++ b/torchchat/distributed/utils.py @@ -15,7 +15,7 @@ import torch -from distributed.logging_utils import SingletonLogger +from torchchat.distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() diff --git a/distributed/verification_utils.py b/torchchat/distributed/verification_utils.py similarity index 99% rename from distributed/verification_utils.py rename to torchchat/distributed/verification_utils.py index 7f02763b6..30632720e 100644 --- a/distributed/verification_utils.py +++ b/torchchat/distributed/verification_utils.py @@ -4,14 +4,12 @@ from collections import OrderedDict, defaultdict from torch._subclasses import FakeTensor import numpy as np -from distributed.dtensor_utils import is_dtensor +from torchchat.distributed.dtensor_utils import is_dtensor, SingletonLogger from typing import Dict, List, Tuple -from distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() - def record_module_dtypes(module): """Record the dtypes of all parameters and buffers in a module and return a dictionary of dtype -> list of names""" dtype_count = defaultdict(int) diff --git a/distributed/version.txt b/torchchat/distributed/version.txt similarity index 100% rename from distributed/version.txt rename to torchchat/distributed/version.txt diff --git a/distributed/world_maker.py b/torchchat/distributed/world_maker.py similarity index 90% rename from distributed/world_maker.py rename to torchchat/distributed/world_maker.py index 8606ae96c..a22120a4f 100644 --- a/distributed/world_maker.py +++ b/torchchat/distributed/world_maker.py @@ -9,14 +9,13 @@ from torch.distributed.device_mesh import DeviceMesh - -from distributed.parallel_config import ParallelDims -from distributed.utils import init_distributed +from torchchat.distributed.parallel_config import ParallelDims +from torchchat.distributed.utils import init_distributed +from torchchat.distributed.logging_utils import SingletonLogger from .config_manager import InferenceConfig -from distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger()