Merged
12 changes: 7 additions & 5 deletions dist_run.py
@@ -4,26 +4,28 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Example run command:
+# torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --pp 2
+# torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2
+
 import argparse
 import os
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
 
-# Run command:
-# torchrun --nproc-per-node 4 dist_run.py
 import torch
 import torch.distributed as dist
 
-from distributed.logging_utils import SingletonLogger
+from torchchat.distributed.logging_utils import SingletonLogger
 
 # TODO - these are not distributed specific, consider moving to new package
-from distributed.safetensor_utils import (
+from torchchat.distributed.safetensor_utils import (
     get_hf_config_file,
     get_hf_weight_map_and_path,
     load_safetensor_weights,
 )
-from distributed.utils import (
+from torchchat.distributed.utils import (
     bytes_to_readable,
     Color as color,
     CUDATrackTime,
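The updated run comment launches four ranks with a pipeline-parallel degree of 2. As a rough sketch of how such flags typically compose (an assumption based on the torchtitan-style ParallelDims used elsewhere in this PR, not something this diff spells out), the ranks not consumed by pipeline stages form the tensor-parallel group:

import os

# Hypothetical degree bookkeeping; WORLD_SIZE is set by torchrun.
world_size = int(os.environ.get("WORLD_SIZE", "4"))
pp = 2                      # value of the --pp flag in the example command
assert world_size % pp == 0, "--pp must divide the number of ranks"
tp = world_size // pp       # remaining ranks go to tensor parallelism
print(f"pipeline parallel = {pp}, tensor parallel = {tp}")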
12 changes: 0 additions & 12 deletions distributed/__init__.py

This file was deleted.

7 changes: 0 additions & 7 deletions run_dist.sh

This file was deleted.

2 changes: 1 addition & 1 deletion torchchat/cli/builder.py
@@ -21,7 +21,7 @@
 except ImportError:
     pass
 
-from distributed import launch_distributed, ParallelDims, parallelize_llama
+from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torch.distributed.device_mesh import DeviceMesh
 
File renamed without changes.
12 changes: 12 additions & 0 deletions torchchat/distributed/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchchat.distributed.checkpoint import load_checkpoints_to_model
+from torchchat.distributed.logging_utils import SingletonLogger
+from torchchat.distributed.parallel_config import ParallelDims
+from torchchat.distributed.parallelize_llama import parallelize_llama
+from torchchat.distributed.utils import init_distributed
+from torchchat.distributed.world_maker import launch_distributed
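With these re-exports in place, callers import from the package root rather than from individual modules; the torchchat/cli/builder.py change above relies on exactly this surface:

# The same import the updated torchchat/cli/builder.py uses.
from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama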
@@ -12,7 +12,7 @@
 
 import torch
 
-from distributed.logging_utils import SingletonLogger
+from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 
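Nearly every hunk in this PR rewrites the same `from distributed.logging_utils import SingletonLogger` line, so the pattern is worth noting: a process-wide logger fetched through a class method. The real logging_utils.py is not part of this diff; the following is only a minimal sketch of what such a singleton accessor usually looks like:

import logging
import threading

class SingletonLogger:
    """Process-wide logger accessor (sketch only; not torchchat's actual implementation)."""

    _logger = None
    _lock = threading.Lock()

    @classmethod
    def get_logger(cls, name: str = "torchchat.distributed") -> logging.Logger:
        # Create the shared logger once, then hand back the same instance.
        with cls._lock:
            if cls._logger is None:
                logger = logging.getLogger(name)
                if not logger.handlers:
                    handler = logging.StreamHandler()
                    handler.setFormatter(
                        logging.Formatter("%(asctime)s %(levelname)s %(message)s")
                    )
                    logger.addHandler(handler)
                logger.setLevel(logging.INFO)
                cls._logger = logger
        return cls._logger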
distributed/dtensor_utils.py → torchchat/distributed/dtensor_utils.py
@@ -4,7 +4,7 @@
 
 from collections import defaultdict
 
-from distributed.logging_utils import SingletonLogger
+from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 
distributed/parallel_config.py → torchchat/distributed/parallel_config.py
@@ -8,7 +8,7 @@
 
 from torch.distributed.device_mesh import init_device_mesh
 
-from distributed.logging_utils import SingletonLogger
+from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 @dataclass
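The hunk above shows parallel_config.py importing init_device_mesh just before a @dataclass definition, which (per the __init__.py re-exports) is where ParallelDims lives. A minimal sketch of a dataclass in that torchtitan style, with hypothetical field names since the class body is truncated in this view:

from dataclasses import dataclass
from torch.distributed.device_mesh import init_device_mesh

@dataclass
class ParallelDims:
    # Hypothetical fields; the real class body is not shown in this diff.
    pp: int          # pipeline-parallel degree
    tp: int          # tensor-parallel degree
    world_size: int  # total ranks, as launched by torchrun

    def build_mesh(self, device_type: str = "cuda"):
        # Lay the ranks out as a 2-D (pp, tp) mesh; degrees must cover the world.
        assert self.pp * self.tp == self.world_size, "pp * tp must equal world_size"
        return init_device_mesh(
            device_type, (self.pp, self.tp), mesh_dim_names=("pp", "tp")
        )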
distributed/parallelize_llama.py → torchchat/distributed/parallelize_llama.py
@@ -11,9 +11,9 @@
     parallelize_module)
 
 
-from distributed.parallel_config import ParallelDims
+from torchchat.distributed.parallel_config import ParallelDims
 
-from distributed.logging_utils import SingletonLogger
+from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 
distributed/safetensor_utils.py → torchchat/distributed/safetensor_utils.py
@@ -14,14 +14,14 @@
 from typing import Dict, Tuple, Set, Optional
 
 
-from distributed.dtensor_utils import is_dtensor, load_into_dtensor
+from torchchat.distributed.dtensor_utils import is_dtensor, load_into_dtensor
 
 
 _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
 _CONFIG_NAME = "config.json"
 
 
-from distributed.logging_utils import SingletonLogger
+from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 
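The two module-level constants name the standard Hugging Face sidecar files: config.json holds the model configuration, and model.safetensors.index.json maps each weight to the shard that stores it. A short sketch of how such an index is typically consumed (the actual helper signatures, e.g. get_hf_weight_map_and_path, are not shown in this diff):

import json
from pathlib import Path

def read_weight_map(model_dir: str) -> dict:
    # The HF index file's "weight_map" maps parameter names to shard filenames,
    # e.g. {"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors"}.
    index_path = Path(model_dir) / "model.safetensors.index.json"
    index = json.loads(index_path.read_text())
    return index["weight_map"]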
2 changes: 1 addition & 1 deletion distributed/utils.py → torchchat/distributed/utils.py
@@ -15,7 +15,7 @@
 import torch
 
 
-from distributed.logging_utils import SingletonLogger
+from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 
@@ -4,14 +4,12 @@
 from collections import OrderedDict, defaultdict
 from torch._subclasses import FakeTensor
 import numpy as np
-from distributed.dtensor_utils import is_dtensor
+from torchchat.distributed.dtensor_utils import is_dtensor, SingletonLogger
 from typing import Dict, List, Tuple
 
-from distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
-
 
 def record_module_dtypes(module):
     """Record the dtypes of all parameters and buffers in a module and return a dictionary of dtype -> list of names"""
     dtype_count = defaultdict(int)
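Per its docstring, record_module_dtypes walks a module's parameters and buffers and reports which dtypes appear where; the body is truncated here, so the usage below assumes the return shape the docstring describes:

import torch.nn as nn

# Hypothetical usage, assuming the dtype -> names mapping the docstring promises.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
dtype_map = record_module_dtypes(model)
# A freshly built model should report every weight and bias under torch.float32.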
distributed/world_maker.py → torchchat/distributed/world_maker.py
@@ -9,14 +9,13 @@
 
 from torch.distributed.device_mesh import DeviceMesh
 
-
-from distributed.parallel_config import ParallelDims
-from distributed.utils import init_distributed
+from torchchat.distributed.parallel_config import ParallelDims
+from torchchat.distributed.utils import init_distributed
+from torchchat.distributed.logging_utils import SingletonLogger
 
 from .config_manager import InferenceConfig
 
 
-from distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 