Skip to content

Commit 09ad83a

Browse files
emlinfacebook-github-bot
authored andcommitted
add virtual table eviction policy (#3172)
Summary: Pull Request resolved: #3172 X-link: facebookresearch/FBGEMM#1498 X-link: pytorch/FBGEMM#4433 Add eviction policy to embedding config and also enable config in mvai model family Reviewed By: duduyi2013, yixin94 Differential Revision: D75660955 fbshipit-source-id: e514f56a88b46f5000f8d54478531f7d4e739f21
1 parent 533f82b commit 09ad83a

File tree

9 files changed

+273
-4
lines changed

9 files changed

+273
-4
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import torch.distributed as dist
3232
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
3333
BackendType,
34+
EvictionPolicy,
3435
KVZCHParams,
3536
)
3637
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
@@ -78,8 +79,14 @@
7879
)
7980
from torchrec.distributed.utils import append_prefix, none_throws
8081
from torchrec.modules.embedding_configs import (
82+
CountBasedEvictionPolicy,
83+
CountTimestampMixedEvictionPolicy,
8184
data_type_to_sparse_type,
85+
FeatureL2NormBasedEvictionPolicy,
86+
NoEvictionPolicy,
8287
pooling_type_to_pooling_mode,
88+
TimestampBasedEvictionPolicy,
89+
VirtualTableEvictionPolicy,
8390
)
8491
from torchrec.optim.fused import (
8592
EmptyFusedOptimizer,
@@ -201,6 +208,7 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
201208
def _populate_zero_collision_tbe_params(
202209
tbe_params: Dict[str, Any],
203210
sharded_local_buckets: List[Tuple[int, int, int]],
211+
config: GroupedEmbeddingConfig,
204212
) -> None:
205213
"""
206214
Construct Zero Collision TBE params from config and fused params dict.
@@ -211,10 +219,77 @@ def _populate_zero_collision_tbe_params(
211219
]
212220
bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets]
213221

222+
enabled = False
223+
for table in config.embedding_tables:
224+
if table.virtual_table_eviction_policy is not None and not isinstance(
225+
table.virtual_table_eviction_policy, NoEvictionPolicy
226+
):
227+
enabled = True
228+
if enabled:
229+
counter_thresholds = [0] * len(config.embedding_tables)
230+
ttls_in_mins = [0] * len(config.embedding_tables)
231+
counter_decay_rates = [0.0] * len(config.embedding_tables)
232+
l2_weight_thresholds = [0.0] * len(config.embedding_tables)
233+
eviction_strategy = -1
234+
table_names = [table.name for table in config.embedding_tables]
235+
for i, table in enumerate(config.embedding_tables):
236+
policy_t = table.virtual_table_eviction_policy
237+
if policy_t is not None:
238+
if isinstance(policy_t, CountBasedEvictionPolicy):
239+
counter_thresholds[i] = policy_t.eviction_threshold
240+
counter_decay_rates[i] = policy_t.decay_rate
241+
if eviction_strategy == -1 or eviction_strategy == 1:
242+
eviction_strategy = 1
243+
else:
244+
raise ValueError(
245+
f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 1 for tables {table_names}"
246+
)
247+
elif isinstance(policy_t, TimestampBasedEvictionPolicy):
248+
ttls_in_mins[i] = policy_t.eviction_ttl_mins
249+
if eviction_strategy == -1 or eviction_strategy == 0:
250+
eviction_strategy = 0
251+
else:
252+
raise ValueError(
253+
f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 0 for tables {table_names}"
254+
)
255+
elif isinstance(policy_t, FeatureL2NormBasedEvictionPolicy):
256+
l2_weight_thresholds[i] = policy_t.eviction_threshold
257+
if eviction_strategy == -1 or eviction_strategy == 3:
258+
eviction_strategy = 3
259+
else:
260+
raise ValueError(
261+
f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 3 for tables {table_names}"
262+
)
263+
elif isinstance(policy_t, CountTimestampMixedEvictionPolicy):
264+
counter_thresholds[i] = policy_t.eviction_threshold
265+
counter_decay_rates[i] = policy_t.decay_rate
266+
ttls_in_mins[i] = policy_t.eviction_ttl_mins
267+
if eviction_strategy == -1 or eviction_strategy == 2:
268+
eviction_strategy = 2
269+
else:
270+
raise ValueError(
271+
f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 2 for tables {table_names}"
272+
)
273+
else:
274+
raise ValueError(
275+
f"Unsupported eviction policy {policy_t} for table {table.name}"
276+
)
277+
eviction_policy = EvictionPolicy(
278+
eviction_trigger_mode=2, # 2 means mem_util based eviction
279+
eviction_strategy=eviction_strategy,
280+
counter_thresholds=counter_thresholds,
281+
ttls_in_mins=ttls_in_mins,
282+
counter_decay_rates=counter_decay_rates,
283+
l2_weight_thresholds=l2_weight_thresholds,
284+
)
285+
else:
286+
eviction_policy = None
287+
214288
tbe_params["kv_zch_params"] = KVZCHParams(
215289
bucket_offsets=bucket_offsets,
216290
bucket_sizes=bucket_sizes,
217291
enable_optimizer_offloading=False,
292+
eviction_policy=eviction_policy,
218293
)
219294

220295

@@ -1318,7 +1393,7 @@ def __init__(
13181393
self._config.embedding_tables, self._pg
13191394
)
13201395
)
1321-
_populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec)
1396+
_populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec, config)
13221397
compute_kernel = config.embedding_tables[0].compute_kernel
13231398
embedding_location = compute_kernel_to_embedding_location(compute_kernel)
13241399

@@ -2124,7 +2199,7 @@ def __init__(
21242199
self._config.embedding_tables, self._pg
21252200
)
21262201
)
2127-
_populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec)
2202+
_populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec, config)
21282203
compute_kernel = config.embedding_tables[0].compute_kernel
21292204
embedding_location = compute_kernel_to_embedding_location(compute_kernel)
21302205

torchrec/distributed/embedding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def create_sharding_infos_by_sharding_device_group(
248248
weight_init_min=config.weight_init_min,
249249
total_num_buckets=config.total_num_buckets,
250250
use_virtual_table=config.use_virtual_table,
251+
virtual_table_eviction_policy=config.virtual_table_eviction_policy,
251252
),
252253
param_sharding=parameter_sharding,
253254
param=param,
@@ -613,6 +614,7 @@ def create_grouped_sharding_infos(
613614
weight_init_min=config.weight_init_min,
614615
total_num_buckets=config.total_num_buckets,
615616
use_virtual_table=config.use_virtual_table,
617+
virtual_table_eviction_policy=config.virtual_table_eviction_policy,
616618
),
617619
param_sharding=parameter_sharding,
618620
param=param,

torchrec/distributed/embeddingbag.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def create_sharding_infos_by_sharding_device_group(
295295
),
296296
total_num_buckets=config.total_num_buckets,
297297
use_virtual_table=config.use_virtual_table,
298+
virtual_table_eviction_policy=config.virtual_table_eviction_policy,
298299
),
299300
param_sharding=parameter_sharding,
300301
param=param,
@@ -693,6 +694,7 @@ def create_grouped_sharding_infos(
693694
),
694695
total_num_buckets=config.total_num_buckets,
695696
use_virtual_table=config.use_virtual_table,
697+
virtual_table_eviction_policy=config.virtual_table_eviction_policy,
696698
),
697699
param_sharding=parameter_sharding,
698700
param=param,

torchrec/distributed/quant_state.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,9 @@ class WeightSpec:
385385
shard_offsets: List[int] # shard offsets
386386
shard_sizes: List[int] # shard sizes
387387
sharding_type: Optional[str] # e.g. ShardingType.ROW_WISE.value=="row_wise"
388+
virtual_table_dim_offsets: Optional[List[int]] = (
389+
None # for virtual table, weight dim offsets for quantization. e.g. [8, 264] for 256 dim tables, the first 8 elements are the metaheader
390+
)
388391

389392

390393
def get_bucket_offsets_per_virtual_table(
@@ -504,6 +507,18 @@ def sharded_tbes_weights_spec(
504507
tables = config.embedding_tables
505508
for table_idx, table in enumerate(tables):
506509
table_name: str = table.name
510+
table_dim_offsets: Optional[List[int]] = (
511+
None
512+
if not table.use_virtual_table
513+
else [0, table.embedding_dim]
514+
)
515+
if table.virtual_table_eviction_policy:
516+
table_dim_offsets = [
517+
table.virtual_table_eviction_policy.get_meta_header_len(),
518+
# pyre-ignore [16]
519+
table.virtual_table_eviction_policy.get_meta_header_len()
520+
+ table.embedding_dim,
521+
]
507522
# pyre-ignore
508523
table_metadata: ShardMetadata = table.local_metadata
509524
local_rows = table.local_rows
@@ -577,6 +592,7 @@ def sharded_tbes_weights_spec(
577592
shard_offsets=shard_offsets,
578593
shard_sizes=shard_sizes,
579594
sharding_type=sharding_type,
595+
virtual_table_dim_offsets=table_dim_offsets,
580596
)
581597

582598
# We also need to populate weight_id tensor for vritual

torchrec/distributed/sharding/rw_sharding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def _shard(
219219
num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning,
220220
total_num_buckets=info.embedding_config.total_num_buckets,
221221
use_virtual_table=info.embedding_config.use_virtual_table,
222+
virtual_table_eviction_policy=info.embedding_config.virtual_table_eviction_policy,
222223
)
223224
)
224225
return tables_per_rank

torchrec/distributed/test_utils/infer_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
dtype_to_data_type,
7878
EmbeddingBagConfig,
7979
QuantConfig,
80+
VirtualTableEvictionPolicy,
8081
)
8182
from torchrec.modules.embedding_modules import EmbeddingBagCollection
8283
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
@@ -642,6 +643,7 @@ def create_test_model(
642643
constraints: Optional[Dict[str, ParameterConstraints]] = None,
643644
weight_dtype: torch.dtype = torch.qint8,
644645
pruning_dict: Optional[Dict[str, int]] = None,
646+
virtual_table_eviction_policy: Optional[VirtualTableEvictionPolicy] = None,
645647
) -> TestModelInfo:
646648
topology: Topology = Topology(
647649
world_size=world_size, compute_device=sparse_device.type
@@ -675,6 +677,8 @@ def create_test_model(
675677
embedding_dim=emb_dim,
676678
name="table_" + str(i),
677679
feature_names=["feature_" + str(i)],
680+
use_virtual_table=True if virtual_table_eviction_policy else False,
681+
virtual_table_eviction_policy=virtual_table_eviction_policy,
678682
)
679683
for i in range(mi.num_features)
680684
]
@@ -685,6 +689,8 @@ def create_test_model(
685689
embedding_dim=emb_dim,
686690
name="weighted_table_" + str(i),
687691
feature_names=["weighted_feature_" + str(i)],
692+
use_virtual_table=True if virtual_table_eviction_policy else False,
693+
virtual_table_eviction_policy=virtual_table_eviction_policy,
688694
)
689695
for i in range(mi.num_weighted_features)
690696
]

torchrec/distributed/tests/test_infer_shardings.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@
6969
from torchrec.distributed.test_utils.test_model import ModelInput
7070
from torchrec.distributed.types import ShardingEnv, ShardingPlan
7171
from torchrec.fx import symbolic_trace
72+
from torchrec.modules.embedding_configs import (
73+
dtype_to_data_type,
74+
TimestampBasedEvictionPolicy,
75+
)
7276
from torchrec.modules.embedding_modules import EmbeddingBagCollection
7377
from torchrec.modules.feature_processor_ import (
7478
FeatureProcessorsCollection,
@@ -357,6 +361,90 @@ def test_rw(self, weight_dtype: torch.dtype, device_type: str) -> None:
357361
ShardingType.ROW_WISE.value,
358362
)
359363

364+
@unittest.skipIf(
365+
torch.cuda.device_count() <= 1,
366+
"Not enough GPUs available",
367+
)
368+
# pyre-ignore
369+
@given(
370+
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
371+
device_type=st.sampled_from(["cuda", "cpu"]),
372+
)
373+
@settings(max_examples=4, deadline=None)
374+
def test_rw_with_virtual_table_eviction(
375+
self, weight_dtype: torch.dtype, device_type: str
376+
) -> None:
377+
num_embeddings = 256
378+
emb_dim = 16
379+
world_size = 2
380+
batch_size = 4
381+
local_device = torch.device(f"{device_type}:0")
382+
eviction_policy = TimestampBasedEvictionPolicy()
383+
eviction_policy.init_metaheader_config(dtype_to_data_type(torch.float16))
384+
mi = create_test_model(
385+
num_embeddings,
386+
emb_dim,
387+
world_size,
388+
batch_size,
389+
dense_device=local_device,
390+
sparse_device=local_device,
391+
quant_state_dict_split_scale_bias=True,
392+
weight_dtype=weight_dtype,
393+
virtual_table_eviction_policy=eviction_policy,
394+
)
395+
396+
non_sharded_model = mi.quant_model
397+
num_emb_half = num_embeddings // 2
398+
expected_shards = [
399+
[
400+
((0, 0, num_emb_half, emb_dim), placement(device_type, 0, world_size)),
401+
(
402+
(num_emb_half, 0, num_emb_half, emb_dim),
403+
placement(device_type, 1, world_size),
404+
),
405+
]
406+
]
407+
sharded_model = shard_qebc(
408+
mi,
409+
sharding_type=ShardingType.ROW_WISE,
410+
device=local_device,
411+
expected_shards=expected_shards,
412+
)
413+
inputs = [
414+
model_input_to_forward_args(inp.to(local_device))
415+
for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
416+
]
417+
418+
sharded_model.load_state_dict(non_sharded_model.state_dict())
419+
420+
sharded_output = sharded_model(*inputs[0])
421+
non_sharded_output = non_sharded_model(*inputs[0])
422+
assert_close(sharded_output, non_sharded_output)
423+
424+
weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model)
425+
assert_weight_spec(
426+
weights_spec,
427+
expected_shards,
428+
"_module.sparse.ebc",
429+
"embedding_bags",
430+
["table_0"],
431+
ShardingType.ROW_WISE.value,
432+
)
433+
print(weights_spec)
434+
assert (
435+
weights_spec[
436+
"_module.sparse.ebc.tbes.0.0.table_0.weight"
437+
].virtual_table_dim_offsets
438+
is not None
439+
)
440+
assert (
441+
# pyre-ignore [16]
442+
weights_spec[
443+
"_module.sparse.ebc.tbes.0.0.table_0.weight"
444+
].virtual_table_dim_offsets[0]
445+
== 8
446+
)
447+
360448
@unittest.skipIf(
361449
torch.cuda.device_count() <= 1,
362450
"Not enough GPUs available",

0 commit comments

Comments
 (0)