Skip to content

Commit 51078e8

Browse files
emlinfacebook-github-bot
authored andcommitted
propagate shard offsets for KV ZCH inference operator (#3178)
Summary: Pull Request resolved: #3178 populate ZCH v.Next sharding offset to inference operator during publish, this offset will be used during weight loading in inference side. Reviewed By: chenyuzhcy Differential Revision: D77989209 fbshipit-source-id: 95fdb750e109dc17eaeea264133437100819ed60
1 parent 09ad83a commit 51078e8

File tree

7 files changed

+193
-7
lines changed

7 files changed

+193
-7
lines changed

torchrec/distributed/embedding_lookup.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,10 +1169,23 @@ def __init__(
11691169
self._is_empty_rank: List[bool] = []
11701170
for rank in range(world_size):
11711171
empty_rank = len(grouped_configs_per_rank[rank]) == 0
1172-
# Propagate shard index to get the correct runtime_device based on shard metadata
1173-
# in case of heterogenous sharding of a single table across different device types
1172+
grouped_configs_per_rank_elem = grouped_configs_per_rank[rank]
1173+
contains_virtual_table = any(
1174+
config.is_using_virtual_table()
1175+
for config in grouped_configs_per_rank_elem
1176+
)
1177+
# In case of heterogenous sharding of a single table acorss
1178+
# different device types i.e. when device_type_from_sharding_infos
1179+
# is a tuple OR if any of the table is virtual table, we can for
1180+
# now assume that the table is row_wise sharded and the shard_index
1181+
# can be set to the rank. shard_index is used downstream to get
1182+
# runtime_device (or row alignment) as well as to get the shard
1183+
# offsets for virtual table
11741184
shard_index = (
1175-
rank if isinstance(device_type_from_sharding_infos, tuple) else None
1185+
rank
1186+
if isinstance(device_type_from_sharding_infos, tuple)
1187+
or contains_virtual_table
1188+
else None
11761189
)
11771190
self._is_empty_rank.append(empty_rank)
11781191
if not empty_rank:
@@ -1235,10 +1248,23 @@ def __init__(
12351248
"meta" if device is not None and device.type == "meta" else "cuda"
12361249
)
12371250
for rank in range(world_size):
1238-
# propagate shard index to get the correct runtime_device based on shard metadata
1239-
# in case of heterogenous sharding of a single table acorss different device types
1251+
grouped_configs_per_rank_elem = grouped_configs_per_rank[rank]
1252+
contains_virtual_table = any(
1253+
config.is_using_virtual_table()
1254+
for config in grouped_configs_per_rank_elem
1255+
)
1256+
# In case of heterogenous sharding of a single table acorss
1257+
# different device types i.e. when device_type_from_sharding_infos
1258+
# is a tuple OR if any of the table is virtual table, we can for
1259+
# now assume that the table is row_wise sharded and the shard_index
1260+
# can be set to the rank. shard_index is used downstream to get
1261+
# runtime_device (or row alignment) as well as to get the shard
1262+
# offsets for virtual table
12401263
shard_index = (
1241-
rank if isinstance(device_type_from_sharding_infos, tuple) else None
1264+
rank
1265+
if isinstance(device_type_from_sharding_infos, tuple)
1266+
or contains_virtual_table
1267+
else None
12421268
)
12431269
device = rank_device(device_type, rank)
12441270
self._embedding_lookups_per_rank.append(

torchrec/distributed/embedding_sharding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ def _group_tables_per_rank(
564564
table.data_type,
565565
),
566566
_prefetch_and_cached(table),
567+
table.use_virtual_table if is_inference else None,
567568
)
568569
# micromanage the order of we traverse the groups to ensure backwards compatibility
569570
if grouping_key not in groups:
@@ -579,6 +580,7 @@ def _group_tables_per_rank(
579580
compute_kernel_type,
580581
_,
581582
_,
583+
use_virtual_table,
582584
) = grouping_key
583585
grouped_tables = groups[grouping_key]
584586
# remove non-native fused params

torchrec/distributed/embedding_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]:
301301
embedding_shard_metadata.append(table.local_metadata)
302302
return embedding_shard_metadata
303303

304+
def is_using_virtual_table(self) -> bool:
305+
return any(table.use_virtual_table for table in self.embedding_tables)
306+
304307

305308
F = TypeVar("F", bound=Multistreamable)
306309
T = TypeVar("T")

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,30 @@ def _quantize_weight(
120120
return quant_weight_list
121121

122122

123+
def _get_shard_offsets_for_kv_zch(
124+
config: GroupedEmbeddingConfig,
125+
shard_index: int,
126+
) -> List[int]:
127+
"""
128+
Given kv zch tables are rw sharded, getting the row offsets for each shard
129+
at level to be used witin kv zch look up kernel
130+
"""
131+
shard_row_offsets = []
132+
for table in config.embedding_tables:
133+
assert (
134+
table.global_metadata is not None
135+
), f"Expected global_metadata to be populated for table {table.name} to get shard offsets for kv zch look up kernel"
136+
assert (
137+
len(table.global_metadata.shards_metadata) > shard_index
138+
), f"Expected table {table.name} to have more shards than shard index {shard_index}. Found {len(table.global_metadata.shards_metadata)} shards"
139+
shard_row_offsets.append(
140+
# pyre-ignore: Undefined attribute [16]
141+
table.global_metadata.shards_metadata[shard_index].shard_offsets[0]
142+
)
143+
logger.info(f"Shard row offsets for kv zch look up table: {shard_row_offsets=}")
144+
return shard_row_offsets
145+
146+
123147
def _get_runtime_device(
124148
device: Optional[torch.device],
125149
config: GroupedEmbeddingConfig,
@@ -293,6 +317,16 @@ def __init__(
293317
else:
294318
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen
295319

320+
if is_virtual_table:
321+
assert (
322+
shard_index is not None and shard_index >= 0
323+
), "valid shard_index must be provided for kv zch batch embedding to compute shard offsets"
324+
shard_offsets_for_kv_zch = _get_shard_offsets_for_kv_zch(
325+
config, shard_index
326+
)
327+
else:
328+
shard_offsets_for_kv_zch = None
329+
296330
self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
297331
embedding_specs=embedding_specs,
298332
device=device,
@@ -310,6 +344,12 @@ def __init__(
310344
)
311345
if device is not None:
312346
self._emb_module.initialize_weights()
347+
if shard_offsets_for_kv_zch is not None:
348+
assert (
349+
tbe_clazz == KVEmbeddingInference
350+
), "shard_offsets_for_kv_zch should be computed only for kv zch kernel"
351+
# pyre-ignore: Call error [29]
352+
self._emb_module.init_tbe_config(shard_offsets_for_kv_zch)
313353

314354
def init_parameters(self) -> None:
315355
pass
@@ -479,6 +519,16 @@ def __init__(
479519
if is_virtual_table
480520
else IntNBitTableBatchedEmbeddingBagsCodegen
481521
)
522+
if is_virtual_table:
523+
assert (
524+
shard_index is not None and shard_index >= 0
525+
), "valid shard_index must be provided for kv zch batch embedding to compute shard offsets"
526+
shard_offsets_for_kv_zch = _get_shard_offsets_for_kv_zch(
527+
config, shard_index
528+
)
529+
else:
530+
shard_offsets_for_kv_zch = None
531+
482532
self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
483533
embedding_specs=[
484534
(
@@ -511,6 +561,12 @@ def __init__(
511561
)
512562
if device is not None:
513563
self._emb_module.initialize_weights()
564+
if shard_offsets_for_kv_zch is not None:
565+
assert (
566+
embedding_clazz == KVEmbeddingInference
567+
), "shard_offsets_for_kv_zch should be computed only for kv zch kernel"
568+
# pyre-ignore: Call error [29]
569+
self._emb_module.init_tbe_config(shard_offsets_for_kv_zch)
514570

515571
@property
516572
def emb_module(

torchrec/distributed/sharding/tw_sharding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def _shard(
184184
weight_init_min=info.embedding_config.weight_init_min,
185185
fused_params=info.fused_params,
186186
num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning,
187+
use_virtual_table=info.embedding_config.use_virtual_table,
187188
)
188189
)
189190
return tables_per_rank

torchrec/distributed/sharding/twrw_sharding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def _shard(
204204
weight_init_max=info.embedding_config.weight_init_max,
205205
weight_init_min=info.embedding_config.weight_init_min,
206206
fused_params=info.fused_params,
207+
use_virtual_table=info.embedding_config.use_virtual_table,
207208
)
208209
)
209210

torchrec/distributed/tests/test_quant_sequence_model_parallel.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
import hypothesis.strategies as st
1515
import torch
16+
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
17+
IntNBitTableBatchedEmbeddingBagsCodegen,
18+
)
19+
from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
1620
from hypothesis import given, settings, Verbosity
1721
from torch import nn, quantization as quant
1822
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -24,7 +28,7 @@
2428
)
2529
from torchrec.distributed.tests.test_sequence_model import TestSequenceSparseNN
2630
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
27-
from torchrec.modules.embedding_configs import EmbeddingConfig
31+
from torchrec.modules.embedding_configs import EmbeddingConfig, NoEvictionPolicy
2832
from torchrec.modules.embedding_modules import EmbeddingCollection
2933
from torchrec.quant.embedding_modules import (
3034
EmbeddingCollection as QuantEmbeddingCollection,
@@ -203,3 +207,96 @@ def test_quant_pred_shard(
203207
)
204208
local_batch = local_batch.to(device)
205209
sharded_quant_model(local_batch.idlist_features)
210+
211+
# pyre-fixme[56]
212+
@unittest.skipIf(
213+
torch.cuda.device_count() <= 1,
214+
"Not enough GPUs available",
215+
)
216+
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
217+
def test_sharded_quant_kv_zch(self) -> None:
218+
device = torch.device("cuda:0")
219+
num_features = 4
220+
221+
tables = [
222+
EmbeddingConfig(
223+
num_embeddings=(i + 1) * 11,
224+
embedding_dim=16,
225+
name="table_" + str(i),
226+
feature_names=["feature_" + str(i)],
227+
use_virtual_table=True if i % 2 == 0 else False,
228+
virtual_table_eviction_policy=(
229+
NoEvictionPolicy() if i % 2 == 0 else None
230+
),
231+
)
232+
for i in range(num_features)
233+
]
234+
# wrap in sequential because _quantize only applies to submodules...
235+
model = nn.Sequential(EmbeddingCollection(tables=tables, device=device))
236+
237+
quant_model = _quantize(model, quant_state_dict_split_scale_bias=True)
238+
239+
sharded_quant_model = _shard_modules(
240+
module=quant_model,
241+
sharders=[
242+
cast(
243+
ModuleSharder[torch.nn.Module],
244+
TestQuantECSharder(
245+
sharding_type=ShardingType.ROW_WISE.value,
246+
kernel_type=EmbeddingComputeKernel.QUANT.value,
247+
),
248+
)
249+
],
250+
device=device,
251+
env=ShardingEnv.from_local(world_size=2, rank=0),
252+
)
253+
254+
sharded_quant_model.load_state_dict(sharded_quant_model.state_dict())
255+
256+
local_batch, _ = ModelInput.generate(
257+
batch_size=16,
258+
world_size=1,
259+
num_float_features=10,
260+
tables=self.tables,
261+
weighted_tables=[],
262+
indices_dtype=torch.int32,
263+
lengths_dtype=torch.int32,
264+
)
265+
local_batch = local_batch.to(device)
266+
sharded_quant_model(local_batch.idlist_features)
267+
self.assertIsInstance(
268+
# pyre-ignore [29]
269+
sharded_quant_model[0]
270+
._lookups[0]
271+
._embedding_lookups_per_rank[0]
272+
._emb_modules[0]
273+
._emb_module,
274+
KVEmbeddingInference,
275+
)
276+
self.assertIsInstance(
277+
# pyre-ignore [29]
278+
sharded_quant_model[0]
279+
._lookups[0]
280+
._embedding_lookups_per_rank[0]
281+
._emb_modules[1]
282+
._emb_module,
283+
IntNBitTableBatchedEmbeddingBagsCodegen,
284+
)
285+
self.assertEqual(
286+
# pyre-ignore [29]
287+
sharded_quant_model[0]
288+
._lookups[0]
289+
._embedding_lookups_per_rank[0]
290+
._emb_modules[0]
291+
._emb_module.table_sharding_offset,
292+
[0, 0],
293+
)
294+
self.assertEqual(
295+
# pyre-ignore [29]
296+
sharded_quant_model[0]
297+
._lookups[0]
298+
._embedding_lookups_per_rank[1]
299+
._emb_modules[0]
300+
._emb_module.table_sharding_offset,
301+
[6, 17],
302+
)

0 commit comments

Comments
 (0)