
Commit fe97ff1

yxsamliu authored and facebook-github-bot committed
Khanin/merged pool emb opt (pytorch#4977)
Summary:
X-link: facebookresearch/FBGEMM#1998

Implemented pitched tensor allocation for better memory alignment.

Differential Revision: D83994225

Pulled By: q10
1 parent 88dc834 commit fe97ff1
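The heart of the change is pitched allocation, the same trick cudaMallocPitch uses: pad each row of a 2D tensor so that every row starts on an alignment boundary (256 bytes here), then expose only the logical width as a view. A minimal standalone sketch of the pattern the diff applies, where the helper name `alloc_pitched` is illustrative rather than part of the commit:

```python
import numpy as np
import torch


def alloc_pitched(
    rows: int,
    width_elems: int,
    dtype: torch.dtype,
    device: torch.device,
    alignment: int = 256,  # row alignment in bytes
) -> torch.Tensor:
    # Bytes per element, e.g. 2 for float16, 1 for uint8.
    elem_size = (
        torch.finfo(dtype).bits // 8
        if dtype.is_floating_point
        else torch.iinfo(dtype).bits // 8
    )
    # Round the row width up to the next multiple of `alignment` bytes.
    pitch_bytes = int(np.ceil(width_elems * elem_size / alignment) * alignment)
    # Allocate the padded storage, then hand back a view of the logical width.
    storage = torch.empty((rows, pitch_bytes // elem_size), dtype=dtype, device=device)
    return storage[:, :width_elems]  # non-contiguous; row stride equals the pitch
```

Every row then begins at a 256-byte-aligned address, at the cost of a non-contiguous result whose row stride exceeds its logical width.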

File tree

2 files changed (+103, -53 lines)


fbgemm_gpu/bench/merge_embeddings_benchmark.py

Lines changed: 62 additions & 31 deletions
@@ -16,7 +16,6 @@
 import numpy as np
 import tabulate
 import torch
-
 from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
@@ -99,46 +98,72 @@ def generate_requests(
     return rs


-# pyre-fixme[3]: Return type must be annotated.
 def _get_random_tensor(
     num_ads: int,
     embedding_dimension: int,
     ads_tables: int,
     data_type: str,
     gpu_idx: int,
     include_quantization: bool,
-):
+    use_pitched: bool,
+    alignment: int = 256,  # alignment in bytes
+) -> torch.Tensor:
+    device = torch.device(f"cuda:{gpu_idx}")
+
     if data_type == "FP16" or include_quantization:
-        result_tensor = torch.randn(
-            num_ads,
-            embedding_dimension * ads_tables,
-            dtype=torch.float16,
-            device=torch.device(f"cuda:{gpu_idx}"),
-        )
+        dtype = torch.float16
+        width_elems = embedding_dimension * ads_tables
+        elem_size = torch.finfo(dtype).bits // 8
+
+        if use_pitched:
+            width_bytes = width_elems * elem_size
+            pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)
+            pitch_elems = pitch_bytes // elem_size
+            storage = torch.empty((num_ads, pitch_elems), dtype=dtype, device=device)
+            result_tensor = storage[:, :width_elems]  # logical view
+        else:
+            result_tensor = torch.randn(
+                num_ads, width_elems, dtype=dtype, device=device
+            )
+
     elif data_type == "INT8":
-        assert (
-            embedding_dimension % 2
-        ) == 0, "needs to align to 2 bytes (half type size) for INT8"
-        result_tensor = torch.randint(
-            0,
-            255,
-            # 2 FP16 numbers for scale and bias, total of 4 bytes overhead
-            size=(num_ads, (embedding_dimension + 4) * ads_tables),
-            dtype=torch.uint8,
-            device=torch.device(f"cuda:{gpu_idx}"),
-        )
+        assert embedding_dimension % 2 == 0, "needs to align to 2 bytes for INT8"
+        dtype = torch.uint8
+        width_elems = (embedding_dimension + 4) * ads_tables
+        elem_size = 1
+
+        if use_pitched:
+            width_bytes = width_elems * elem_size
+            pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)
+            pitch_elems = pitch_bytes // elem_size
+            storage = torch.randint(
+                0, 255, (num_ads, pitch_elems), dtype=dtype, device=device
+            )
+            result_tensor = storage[:, :width_elems]
+        else:
+            result_tensor = torch.randint(
+                0, 255, (num_ads, width_elems), dtype=dtype, device=device
+            )
+
     elif data_type == "INT4":
-        assert (
-            embedding_dimension % 4
-        ) == 0, "needs to align to 2 bytes (half type size) for INT4"
-        result_tensor = torch.randint(
-            0,
-            255,
-            # Using torch.uint8 for int4 storage
-            size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables),
-            dtype=torch.uint8,
-            device=torch.device(f"cuda:{gpu_idx}"),
-        )
+        assert embedding_dimension % 4 == 0, "needs to align to 2 bytes for INT4"
+        dtype = torch.uint8
+        width_elems = (embedding_dimension // 2 + 4) * ads_tables
+        elem_size = 1
+
+        if use_pitched:
+            width_bytes = width_elems * elem_size
+            pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)
+            pitch_elems = pitch_bytes // elem_size
+            storage = torch.randint(
+                0, 255, (num_ads, pitch_elems), dtype=dtype, device=device
+            )
+            result_tensor = storage[:, :width_elems]
+        else:
+            result_tensor = torch.randint(
+                0, 255, (num_ads, width_elems), dtype=dtype, device=device
+            )
+
     else:
         raise ValueError

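To make the rounding in `_get_random_tensor` concrete, a worked example with illustrative numbers (not the benchmark defaults):

```python
# FP16 branch with embedding_dimension=172, ads_tables=2 (illustrative values):
width_elems = 172 * 2                     # 344 half-precision elements per row
width_bytes = 344 * 2                     # 688 bytes
pitch_bytes = ((688 + 255) // 256) * 256  # rounds up to 768 bytes
pitch_elems = 768 // 2                    # 384 elements, i.e. 40 padding elements per row
```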
@@ -253,6 +278,7 @@ def benchmark( # noqa C901
     num_ads: int,
     embedding_dimension: int,
     ads_tables: int,
+    use_pitched: bool,
     iters: int = 10,
     p2p_bw: bool = False,
     dst_device: int = 0,
@@ -298,6 +324,7 @@ def benchmark( # noqa C901
             data_type,
             gpu_idx,
             include_quantization,
+            use_pitched,
         )
         for gpu_idx in range(num_gpus)
     ]
@@ -485,6 +512,7 @@ def pool_func_with_quantization(
 @click.option("--num_of_embeddings", default=100000, type=int)
 @click.option("--pooling_factor", default=25, type=int)
 @click.option("--sweep", is_flag=True, default=False)
+@click.option("--use_pitched", is_flag=True, default=False)
 def cli(
     all_to_one_only: bool,
     sum_reduce_to_one_only: bool,
@@ -500,6 +528,7 @@ def cli(
     num_of_embeddings: int,
     pooling_factor: int,
     sweep: bool,
+    use_pitched: bool,
 ) -> None:
     csv_header = (
         "mode, data_type, num_ads, embedding_dimension, ads_tables, num_gpus, dst_device, all_to_one_only, "
@@ -534,6 +563,7 @@ def handler(signum, frame):
             num_ads,
             embedding_dimension,
             ads_tables,
+            use_pitched,
             iters,
             p2p_bw,
             dst_device,
@@ -558,6 +588,7 @@ def handler(signum, frame):
             num_ads,
             embedding_dimension,
             ads_tables,
+            use_pitched,
             iters,
             p2p_bw,
             dst_device,

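With `use_pitched` threaded through `cli` and `benchmark`, pitched allocation in the benchmark is opt-in, presumably invoked as `python fbgemm_gpu/bench/merge_embeddings_benchmark.py --use_pitched` given the click entry point; without the flag the benchmark allocates contiguous tensors exactly as before.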
fbgemm_gpu/test/merge_pooled_embeddings_test.py

Lines changed: 41 additions & 22 deletions
@@ -13,6 +13,7 @@
 import fbgemm_gpu

 import hypothesis.strategies as st
+import numpy as np
 import torch
 from hypothesis import given, settings, Verbosity

@@ -32,8 +33,29 @@
 typed_gpu_unavailable: tuple[bool, str] = gpu_unavailable


-@unittest.skipIf(*gpu_unavailable)
-@unittest.skipIf(open_source, "Not supported in open source yet")
+def make_pitched_tensor(
+    height: int,
+    width: int,
+    dtype: torch.dtype,
+    # pyre-fixme[2]: Parameter must be annotated.
+    device,
+    alignment: int = 256,
+) -> torch.Tensor:
+    elem_size = (
+        torch.finfo(dtype).bits // 8
+        if dtype.is_floating_point
+        else torch.iinfo(dtype).bits // 8
+    )
+    width_bytes = width * elem_size
+    pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment)
+    pitch_elems = pitch_bytes // elem_size
+    storage = torch.randn((height, pitch_elems), dtype=dtype, device=device)
+    view = storage[:, :width]  # logical shape
+    return view.contiguous() if alignment == 0 else view  # return pitched view
+
+
+# @unittest.skipIf(open_source, "Not supported in open source yet")
+@unittest.skipIf(*typed_gpu_unavailable)
 class MergePooledEmbeddingsTest(unittest.TestCase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
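A quick sanity check of what `make_pitched_tensor` returns (hypothetical usage, not part of the test file): a 20-element FP32 row occupies 80 bytes, which rounds up to 256 bytes, i.e. 64 elements, so the view is non-contiguous with a row stride of 64.

```python
t = make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256)
assert t.shape == (10, 20)    # logical shape is unchanged
assert t.stride() == (64, 1)  # 256 bytes / 4 bytes per float32 = 64 elements
assert not t.is_contiguous()  # padding sits between logical rows
```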
@@ -51,16 +73,11 @@ class MergePooledEmbeddingsTest(unittest.TestCase):
     @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
     def test_merge(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        num_ads,
-        # pyre-fixme[2]: Parameter must be annotated.
-        embedding_dimension,
-        # pyre-fixme[2]: Parameter must be annotated.
-        ads_tables,
-        # pyre-fixme[2]: Parameter must be annotated.
-        num_gpus,
-        # pyre-fixme[2]: Parameter must be annotated.
-        non_default_stream,
+        num_ads: int,
+        embedding_dimension: int,
+        ads_tables: int,
+        num_gpus: int,
+        non_default_stream: bool,
         # pyre-fixme[2]: Parameter must be annotated.
         r,
         dim: int,
@@ -107,27 +124,32 @@ def ref(pooled_ad_embeddings, batch_indices):
         torch.testing.assert_close(output_ref, output_cpu)

     # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
-    # 10)` to decorator factory `hypothesis.given`.
     @given(
         num_inputs=st.integers(min_value=1, max_value=10),
         num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
         r=st.randoms(use_true_random=False),
+        use_pitched=st.booleans(),
     )
     # Can instantiate 8 contexts which takes a long time.
     @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
     def test_all_to_one_device(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        num_inputs,
-        # pyre-fixme[2]: Parameter must be annotated.
-        num_gpus,
+        num_inputs: int,
+        num_gpus: int,
         # pyre-fixme[2]: Parameter must be annotated.
         r,
+        use_pitched: bool,
     ) -> None:
         dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
         with torch.cuda.device(dst_device):
-            inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
+            if use_pitched:
+                inputs = [
+                    make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256)
+                    for _ in range(num_inputs)
+                ]
+            else:
+                inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
+
             cuda_inputs = [
                 input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs)
             ]
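The call under test sits below this hunk. Judging from the test's name and FBGEMM's merge-pooled-embeddings operators, the continuation presumably gathers the per-GPU tensors onto `dst_device` along these lines (a sketch under that assumption, not code from the diff):

```python
# Assumed continuation (not visible in this hunk): move every per-GPU input
# to dst_device and check the values survive the pitched layout.
outputs = torch.ops.fbgemm.all_to_one_device(cuda_inputs, dst_device)
for inp, out in zip(inputs, outputs):
    torch.testing.assert_close(out.cpu(), inp)
```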
@@ -150,8 +172,6 @@ def test_merge_pooled_embeddings_gpu_to_cpu(self) -> None:
         torch.testing.assert_close(output, ref_output)

     # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
-    # 10)` to decorator factory `hypothesis.given`.
     @given(
         num_inputs=st.integers(min_value=1, max_value=8),
         num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
@@ -234,7 +254,6 @@ def test_sum_reduce_to_one(
             cuda_output.cpu(), torch.stack(inputs).sum(dim=0)
         )

-    @unittest.skipIf(*typed_gpu_unavailable)
     def test_merge_pooled_embeddings_meta(self) -> None:
         """
         Test that merge_pooled_embeddings works with meta tensor and
