Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 9cac4ad

Browse files
authored
Merge branch 'main' into android-update
2 parents 238c6bb + c454026 commit 9cac4ad

24 files changed

+54
-50
lines changed

dist_run.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,28 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# Example run command:
8+
# torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --pp 2
9+
# torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2
10+
711
import argparse
812
import os
913
from pathlib import Path
1014
from types import SimpleNamespace
1115
from typing import Any, Dict, List, Optional, Tuple
1216

13-
# Run command:
14-
# torchrun --nproc-per-node 4 dist_run.py
1517
import torch
1618
import torch.distributed as dist
1719

18-
from distributed.logging_utils import SingletonLogger
20+
from torchchat.distributed.logging_utils import SingletonLogger
1921

2022
# TODO - these are not distributed specific, consider moving to new package
21-
from distributed.safetensor_utils import (
23+
from torchchat.distributed.safetensor_utils import (
2224
get_hf_config_file,
2325
get_hf_weight_map_and_path,
2426
load_safetensor_weights,
2527
)
26-
from distributed.utils import (
28+
from torchchat.distributed.utils import (
2729
bytes_to_readable,
2830
Color as color,
2931
CUDATrackTime,

distributed/__init__.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

install/install_requirements.sh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ REQUIREMENTS_TO_INSTALL=(
9090
# Rely on the latest tochtune for flamingo support
9191
(
9292
set -x
93-
$PIP_EXECUTABLE install git+https://github.com/pytorch/torchtune.git@18efc81dda1c537bb7c25058ff059b4623ccff58
93+
$PIP_EXECUTABLE install -I git+https://github.com/pytorch/torchtune.git@d002d45e3ec700fa770d9dcc61b02c59e2507bf6
9494
)
9595

9696
if [[ -x "$(command -v nvidia-smi)" ]]; then
@@ -99,3 +99,9 @@ if [[ -x "$(command -v nvidia-smi)" ]]; then
9999
$PYTHON_EXECUTABLE torchchat/utils/scripts/patch_triton.py
100100
)
101101
fi
102+
103+
104+
(
105+
set -x
106+
$PIP_EXECUTABLE install lm-eval=="0.4.2"
107+
)

install/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ snakeviz
1414
sentencepiece
1515
numpy < 2.0
1616
gguf
17-
lm-eval==0.4.2
1817
blobfile
1918
tomli >= 1.1.0 ; python_version < "3.11"
2019
openai

run_dist.sh

Lines changed: 0 additions & 7 deletions
This file was deleted.

torchchat/cli/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
2020

21-
from distributed import launch_distributed, ParallelDims, parallelize_llama
21+
from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
2222

2323
from torch.distributed.device_mesh import DeviceMesh
2424

File renamed without changes.

torchchat/distributed/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from torchchat.distributed.checkpoint import load_checkpoints_to_model
8+
from torchchat.distributed.logging_utils import SingletonLogger
9+
from torchchat.distributed.parallel_config import ParallelDims
10+
from torchchat.distributed.parallelize_llama import parallelize_llama
11+
from torchchat.distributed.utils import init_distributed
12+
from torchchat.distributed.world_maker import launch_distributed
File renamed without changes.

distributed/config_manager.py renamed to torchchat/distributed/config_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import torch
1414

15-
from distributed.logging_utils import SingletonLogger
15+
from torchchat.distributed.logging_utils import SingletonLogger
1616
logger = SingletonLogger.get_logger()
1717

1818

0 commit comments

Comments
 (0)