Skip to content

Commit 76cf40f

Browse files
committed
[Tests] Add multi-GPU integration tests for DDP quantization
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
1 parent ac0cc2a commit 76cf40f

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
Run with: torchrun --nproc_per_node=2 -m pytest <this_file> -v
3+
"""
4+
5+
import os
6+
7+
import pytest
8+
import torch
9+
import torch.distributed as dist
10+
11+
from llmcompressor.utils.distributed import (
12+
all_reduce_max,
13+
all_reduce_min,
14+
get_rank,
15+
get_world_size,
16+
is_distributed,
17+
)
18+
from tests.testing_utils import requires_gpu
19+
20+
# When launched via torchrun, every worker carries RANK/LOCAL_RANK in its
# environment; bring up the NCCL process group exactly once per process.
_launched_by_torchrun = os.environ.get("RANK") is not None
if _launched_by_torchrun and torch.cuda.is_available() and not dist.is_initialized():
    dist.init_process_group(backend="nccl")
    # Pin this process to its assigned GPU (falls back to device 0).
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
29+
30+
def _skip_if_not_distributed():
    """Skip the calling test unless a distributed process group is active."""
    if is_distributed():
        return
    pytest.skip("Requires torchrun --nproc_per_node=2")
33+
34+
35+
@pytest.mark.multi_gpu
@requires_gpu(2)
def test_all_reduce_min_max():
    """all_reduce_min / all_reduce_max must return elementwise extrema across ranks."""
    _skip_if_not_distributed()

    # Each rank contributes different local statistics.
    if get_rank() == 0:
        local_mins = torch.tensor([1.0, 3.0], device="cuda")
        local_maxs = torch.tensor([10.0, 20.0], device="cuda")
    else:
        local_mins = torch.tensor([2.0, 1.0], device="cuda")
        local_maxs = torch.tensor([15.0, 10.0], device="cuda")

    # Elementwise min/max over both ranks' tensors.
    expected_min = torch.tensor([1.0, 1.0], device="cuda")
    expected_max = torch.tensor([15.0, 20.0], device="cuda")
    assert torch.equal(all_reduce_min(local_mins), expected_min)
    assert torch.equal(all_reduce_max(local_maxs), expected_max)
54+
55+
56+
@pytest.mark.multi_gpu
@requires_gpu(2)
def test_synced_qparams_are_identical_across_ranks():
    """Scales computed from globally-reduced min/max must match on every rank.

    Each rank starts from different local min/max statistics; after
    all-reducing them, calculate_qparams must yield an identical scale
    everywhere, which is verified via all_gather.
    """
    _skip_if_not_distributed()
    rank = get_rank()

    # Imported lazily so collecting this module does not require
    # compressed-tensors in environments that skip the test anyway.
    from compressed_tensors.quantization import QuantizationArgs
    from compressed_tensors.quantization.utils import calculate_qparams

    args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")

    local_min = (
        torch.tensor([-2.0], device="cuda")
        if rank == 0
        else torch.tensor([-5.0], device="cuda")
    )
    local_max = (
        torch.tensor([3.0], device="cuda")
        if rank == 0
        else torch.tensor([1.0], device="cuda")
    )

    # clone() so the in-place reduction does not mutate the local tensors.
    global_min = all_reduce_min(local_min.clone())
    global_max = all_reduce_max(local_max.clone())

    scale, _ = calculate_qparams(
        min_vals=global_min,
        max_vals=global_max,
        quantization_args=args,
    )

    # Gather every rank's scale and compare them all against rank 0's.
    # (Previously only gathered[0] vs gathered[1] were compared, which
    # silently ignored extra ranks when world size > 2.)
    gathered = [torch.zeros_like(scale) for _ in range(get_world_size())]
    dist.all_gather(gathered, scale)
    for other_scale in gathered[1:]:
        assert torch.equal(gathered[0], other_scale)

0 commit comments

Comments
 (0)