Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit daf902c

Browse files
authored
Merge branch 'main' into refactor/distributed_inference_without_abstraction
2 parents e07b03d + fd1857a commit daf902c

File tree

4 files changed

+57
-40
lines changed

4 files changed

+57
-40
lines changed

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
7d7c14e898eca3fe66138d2a9445755a9270b800
1+
2e032c6b0de960dee554dcb08126ace718b14c6d

install/install_requirements.sh

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,31 +44,20 @@ fi
4444

4545
echo "Using pip executable: $PIP_EXECUTABLE"
4646

47-
#
48-
# First install requirements in install/requirements.txt. Older torch may be
49-
# installed from the dependency of other models. It will be overridden by
50-
# newer version of torch nightly installed later in this script.
51-
#
52-
53-
(
54-
set -x
55-
$PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cu121
56-
)
57-
5847
# Since torchchat often uses main-branch features of pytorch, only the nightly
5948
# pip versions will have the required features. The PYTORCH_NIGHTLY_VERSION value should
6049
# agree with the third-party/pytorch pinned submodule commit.
6150
#
6251
# NOTE: If a newly-fetched version of the executorch repo changes the value of
6352
# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
6453
# package versions.
65-
PYTORCH_NIGHTLY_VERSION=dev20241213
54+
PYTORCH_NIGHTLY_VERSION=dev20241218
6655

6756
# Nightly version for torchvision
68-
VISION_NIGHTLY_VERSION=dev20241213
57+
VISION_NIGHTLY_VERSION=dev20241218
6958

7059
# Nightly version for torchtune
71-
TUNE_NIGHTLY_VERSION=dev20241126
60+
TUNE_NIGHTLY_VERSION=dev20241218
7261

7362
# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
7463
(
@@ -96,6 +85,16 @@ REQUIREMENTS_TO_INSTALL=(
9685
torchtune=="0.5.0.${TUNE_NIGHTLY_VERSION}"
9786
)
9887

88+
#
89+
# First install requirements in install/requirements.txt. Older torch may be
90+
# installed from the dependency of other models. It will be overridden by
91+
# newer version of torch nightly installed later in this script.
92+
#
93+
(
94+
set -x
95+
$PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url "${TORCH_NIGHTLY_URL}"
96+
)
97+
9998
# Install the requirements. --extra-index-url tells pip to look for package
10099
# versions on the provided URL if they aren't available on the default URL.
101100
(

torchchat/utils/gguf_loader.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
pack_scales_and_zeros,
2525
)
2626

27+
from torchao.dtypes.utils import is_device
28+
2729

2830
logger: logging.Logger = logging.getLogger(__name__)
2931

@@ -128,6 +130,7 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
128130
groupsize,
129131
scales_and_zeros,
130132
)
133+
131134
new_shape = origin_input_size[:-1] + (out_features,)
132135
c = c.reshape(new_shape)
133136
return c
@@ -178,16 +181,27 @@ def __init__(
178181
), "must specify both weights and scales_and_zeros, or neither"
179182

180183
if weight is None:
181-
weight = torch.empty(
182-
(
183-
out_features // 8,
184-
in_features // (inner_k_tiles * 16),
185-
32,
186-
inner_k_tiles // 2,
187-
),
188-
dtype=torch.int32,
189-
device=device,
190-
)
184+
if is_device(device, "cpu"):
185+
weight = torch.empty(
186+
(
187+
out_features,
188+
in_features // 2,
189+
),
190+
dtype=torch.uint8,
191+
device=device,
192+
)
193+
else:
194+
weight = torch.empty(
195+
(
196+
out_features // 8,
197+
in_features // (inner_k_tiles * 16),
198+
32,
199+
inner_k_tiles // 2,
200+
),
201+
dtype=torch.int32,
202+
device=device,
203+
)
204+
191205
scales_and_zeros = torch.empty(
192206
(in_features // groupsize, out_features, 2),
193207
dtype=get_precision(),
@@ -223,12 +237,17 @@ def _prepare_weight_and_scales_and_zeros(
223237
weight_int32, scales_and_zeros = group_quantize_tensor(
224238
weight_bf16, n_bit=4, groupsize=groupsize
225239
)
226-
weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
227-
torch.uint8
228-
)
229-
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
230-
weight_uint8, inner_k_tiles
231-
)
240+
if is_device(weight_int32.device.type, "cpu"):
241+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
242+
weight_int32, inner_k_tiles
243+
)
244+
else:
245+
weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
246+
torch.uint8
247+
)
248+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
249+
weight_uint8, inner_k_tiles
250+
)
232251
return weight_int4pack, scales_and_zeros
233252

234253
@classmethod
@@ -609,17 +628,14 @@ def load_model_and_state_dict(
609628
if load_state_dict:
610629
q, s, z = Q4_0.unpack(t)
611630
scales_and_zeros = pack_scales_and_zeros(s, z)
612-
q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
613-
614-
if torch.device(device).type == "cpu":
615-
weight_int4pack = (
616-
torch.ops.aten._convert_weight_to_int4pack_for_cpu(
617-
q, inner_k_tiles
618-
)
631+
if is_device(q.device.type, "cpu"):
632+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
633+
q, inner_k_tiles
619634
)
620635
else:
636+
q_tmp = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
621637
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
622-
q_uint8, inner_k_tiles
638+
q_tmp, inner_k_tiles
623639
)
624640
state_dict[f"{fqn}.weight"] = weight_int4pack
625641
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
@@ -632,7 +648,7 @@ def load_model_and_state_dict(
632648
in_features=in_features,
633649
out_features=out_features,
634650
bias=False,
635-
device="meta",
651+
device="cpu",
636652
groupsize=Q4_0.groupsize,
637653
inner_k_tiles=inner_k_tiles,
638654
),

torchchat/utils/quantize.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,13 +932,15 @@ def quantized_model(self) -> nn.Module:
932932
libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
933933
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
934934
torch.ops.load_library(libs[0])
935+
print("Loaded torchao cpu ops.")
935936
except Exception as e:
936937
print("Unabled to load torchao cpu ops library. Slow fallback kernels will be used.")
937938

938939
try:
939940
libname = "libtorchao_ops_mps_aten.dylib"
940941
libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
941942
torch.ops.load_library(libpath)
943+
print("Loaded torchao mps ops.")
942944
except Exception as e:
943945
print("Unabled to load torchao mps ops library.")
944946

0 commit comments

Comments (0)