Skip to content

Commit c23f388

Browse files
committed
Unit test of counter table's checkpoint.
1 parent 1c287be commit c23f388

File tree

5 files changed

+140
-39
lines changed

5 files changed

+140
-39
lines changed

corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ def find_files(root_path: str, table_name: str, suffix: str) -> Tuple[List[str],
141141
"emb_values": partial(encode_checkpoint_file_path, item="values"),
142142
"emb_scores": partial(encode_checkpoint_file_path, item="scores"),
143143
"opt_values": partial(encode_checkpoint_file_path, item="opt_values"),
144+
"counter_keys": partial(encode_counter_checkpoint_file_path, item="keys"),
145+
"counter_frequencies": partial(
146+
encode_counter_checkpoint_file_path, item="frequencies"
147+
),
144148
}
145149
if suffix not in suffix_to_encode_file_path_func:
146150
raise RuntimeError(f"Invalid suffix: {suffix}")
@@ -1232,6 +1236,7 @@ def dump(
12321236
self,
12331237
save_dir: str,
12341238
optim: bool = False,
1239+
counter: bool = False,
12351240
table_names: Optional[List[str]] = None,
12361241
pg: Optional[dist.ProcessGroup] = None,
12371242
) -> None:
@@ -1245,7 +1250,7 @@ def dump(
12451250
world_size = dist.get_world_size(group=pg)
12461251

12471252
self.flush()
1248-
for table_name, storage, counter in zip(
1253+
for table_name, storage, counter_table in zip(
12491254
self._table_names, self._storages, self._admission_counter
12501255
):
12511256
if table_name not in set(table_names):
@@ -1278,20 +1283,28 @@ def dump(
12781283
include_meta=(rank == 0),
12791284
)
12801285

1286+
if not counter:
1287+
continue
1288+
12811289
counter_key_path = encode_counter_checkpoint_file_path(
12821290
save_dir, table_name, rank, world_size, "keys"
12831291
)
12841292
counter_frequency_path = encode_counter_checkpoint_file_path(
12851293
save_dir, table_name, rank, world_size, "frequencies"
12861294
)
12871295

1288-
if counter is not None:
1289-
counter.dump(counter_key_path, counter_frequency_path)
1296+
if counter_table is not None:
1297+
counter_table.dump(counter_key_path, counter_frequency_path)
1298+
else:
1299+
warnings.warn(
1300+
f"Counter table is none and will not dump it for table: {table_name}"
1301+
)
12901302

12911303
def load(
12921304
self,
12931305
save_dir: str,
12941306
optim: bool = False,
1307+
counter: bool = False,
12951308
table_names: Optional[List[str]] = None,
12961309
pg: Optional[dist.ProcessGroup] = None,
12971310
):
@@ -1305,7 +1318,7 @@ def load(
13051318
rank = dist.get_rank(group=pg)
13061319
world_size = dist.get_world_size(group=pg)
13071320

1308-
for table_name, storage, counter in zip(
1321+
for table_name, storage, counter_table in zip(
13091322
self._table_names, self._storages, self._admission_counter
13101323
):
13111324
if table_name not in set(table_names):
@@ -1338,11 +1351,16 @@ def load(
13381351
include_optim=optim,
13391352
)
13401353

1341-
if counter is None:
1354+
if not counter:
1355+
continue
1356+
if counter_table is None:
1357+
warnings.warn(
1358+
f"Counter table is none and will not load for table: {table_name}"
1359+
)
13421360
continue
13431361
num_counter_key_files = len(counter_key_files)
13441362
for i in range(num_counter_key_files):
1345-
counter.load(counter_key_files[i], counter_frequency_files[i])
1363+
counter_table.load(counter_key_files[i], counter_frequency_files[i])
13461364

13471365
def export_keys_values(
13481366
self, table_name: str, device: torch.device, batch_size: int = 65536

corelib/dynamicemb/dynamicemb/dump_load.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def DynamicEmbDump(
9393
model: nn.Module,
9494
table_names: Optional[Dict[str, List[str]]] = None,
9595
optim: Optional[bool] = False,
96+
counter: Optional[bool] = False,
9697
pg: dist.ProcessGroup = dist.group.WORLD,
9798
allow_overwrite: bool = False,
9899
) -> None:
@@ -115,6 +116,8 @@ def DynamicEmbDump(
115116
and the value is a list of dynamic embedding table names within that collection. Defaults to None.
116117
optim : Optional[bool], optional
117118
Whether to dump the optimizer states. Defaults to False.
119+
counter : Optional[bool], optional
120+
Whether to dump the embedding admission counter table. Defaults to False.
118121
pg : Optional[dist.ProcessGroup], optional
119122
The process group used to control the communication scope in the dump. Defaults to None.
120123
@@ -175,6 +178,7 @@ def DynamicEmbDump(
175178
dynamic_emb_module.dump(
176179
full_collection_path,
177180
optim=optim,
181+
counter=counter,
178182
table_names=table_names_to_dump,
179183
pg=pg,
180184
)
@@ -197,6 +201,7 @@ def DynamicEmbLoad(
197201
model: nn.Module,
198202
table_names: Optional[List[str]] = None,
199203
optim: bool = False,
204+
counter: bool = False,
200205
pg: dist.ProcessGroup = dist.group.WORLD,
201206
):
202207
"""
@@ -216,6 +221,8 @@ def DynamicEmbLoad(
216221
and the value is a list of dynamic embedding table names within that collection. Defaults to None.
217222
optim : bool, optional
218223
Whether to load the optimizer states. Defaults to False.
224+
counter : bool, optional
225+
Whether to load the embedding admission counter table. Defaults to False.
219226
pg : Optional[dist.ProcessGroup], optional
220227
The process group used to control the communication scope in the load. Defaults to None.
221228
@@ -257,6 +264,7 @@ def DynamicEmbLoad(
257264
dynamic_emb_module.load(
258265
full_collection_path,
259266
optim=optim,
267+
counter=counter,
260268
table_names=table_names_to_load,
261269
pg=pg,
262270
)

corelib/dynamicemb/dynamicemb/embedding_admission.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def dump(self, key_file, counter_file) -> None:
106106
key_file (str): the file path of keys.
107107
counter_file (str): the file path of frequencies.
108108
"""
109+
print(f"Counter size: {self.table_.size()}")
109110
self.table_.dump(key_file, {self.score_name_: counter_file})
110111

111112

corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
import torch
2525
import torch.distributed as dist
2626
import torch.nn as nn
27-
from dynamicemb import DynamicEmbScoreStrategy, DynamicEmbTableOptions
27+
from dynamicemb import (
28+
DynamicEmbScoreStrategy,
29+
DynamicEmbTableOptions,
30+
FrequencyAdmissionStrategy,
31+
)
2832
from dynamicemb.dump_load import (
2933
DynamicEmbDump,
3034
DynamicEmbLoad,
@@ -38,6 +42,7 @@
3842
from dynamicemb.embedding_admission import KVCounter
3943
from dynamicemb.get_planner import get_planner
4044
from dynamicemb.key_value_table import batched_export_keys_values
45+
from dynamicemb.scored_hashtable import ScoreArg, ScorePolicy
4146
from dynamicemb.shard import DynamicEmbeddingCollectionSharder
4247
from dynamicemb.types import AdmissionStrategy
4348
from dynamicemb.utils import TORCHREC_TYPES
@@ -318,6 +323,40 @@ def create_model(
318323
return model
319324

320325

326+
def check_counter_table_checkpoint(x, y):
327+
device = torch.cuda.current_device()
328+
tables_x = get_dynamic_emb_module(x)
329+
tables_y = get_dynamic_emb_module(y)
330+
331+
for table_x, table_y in zip(tables_x, tables_y):
332+
for cnt_tx, cnt_ty in zip(table_x, table_y):
333+
assert cnt_tx.table_.size() == cnt_ty.table_.size()
334+
335+
for keys, named_scores in cnt_tx._batched_export_keys_scores(
336+
cnt_tx.table_.score_names_, torch.device(f"cuda:{device}")
337+
):
338+
if keys.numel() == 0:
339+
continue
340+
freq_name = cnt_tx.table_.score_names_[0]
341+
frequencies = named_scores[freq_name]
342+
343+
score_args_lookup = [
344+
ScoreArg(
345+
name=freq_name,
346+
value=torch.zeros_like(frequencies),
347+
policy=ScorePolicy.CONST,
348+
is_return=True,
349+
)
350+
]
351+
founds = torch.empty(
352+
keys.numel(), dtype=torch.bool, device=device
353+
).fill_(False)
354+
355+
cnt_ty.lookup(keys, score_args_lookup, founds)
356+
357+
assert torch.equal(frequencies, score_args_lookup)
358+
359+
321360
@click.command()
322361
@click.option("--num-embedding-collections", type=int, required=True)
323362
@click.option("--num-embeddings", type=str, required=True)
@@ -336,6 +375,7 @@ def create_model(
336375
required=True,
337376
)
338377
@click.option("--optim", type=bool, required=True)
378+
@click.option("--counter", type=bool, required=True)
339379
def test_model_load_dump(
340380
num_embedding_collections: int,
341381
num_embeddings: str,
@@ -346,6 +386,7 @@ def test_model_load_dump(
346386
mode: str,
347387
save_path: str,
348388
optim: bool,
389+
counter: bool,
349390
batch_size: int = 128,
350391
num_iterations: int = 10,
351392
):
@@ -367,6 +408,9 @@ def test_model_load_dump(
367408
embedding_dim=embedding_dim,
368409
optimizer_kwargs=optimizer_kwargs,
369410
score_strategy=score_strategy_,
411+
admit_strategy=FrequencyAdmissionStrategy(
412+
threshold=2 if counter else 1,
413+
),
370414
)
371415

372416
kjts, feature_names, all_kjts = generate_sparse_feature(
@@ -388,7 +432,7 @@ def test_model_load_dump(
388432

389433
if mode == "dump":
390434
shutil.rmtree(save_path, ignore_errors=True)
391-
DynamicEmbDump(save_path, ref_model, optim=optim)
435+
DynamicEmbDump(save_path, ref_model, optim=optim, counter=counter)
392436

393437
if mode == "load":
394438
model = create_model(
@@ -397,16 +441,24 @@ def test_model_load_dump(
397441
embedding_dim=embedding_dim,
398442
optimizer_kwargs=optimizer_kwargs,
399443
score_strategy=score_strategy_,
444+
admit_strategy=FrequencyAdmissionStrategy(
445+
threshold=2 if counter else 1,
446+
),
400447
)
401448

402-
DynamicEmbLoad(save_path, model, optim=optim)
449+
DynamicEmbLoad(save_path, model, optim=optim, counter=counter)
450+
451+
if counter:
452+
check_counter_table_checkpoint(model, ref_model)
403453

404454
table_name_to_key_score_dict = {}
405455
for _, _, sharded_module in find_sharded_modules(model):
406456
dynamic_emb_modules = get_dynamic_emb_module(sharded_module)
407457
for dynamic_emb_module in dynamic_emb_modules:
408-
for table_name, table in zip(
409-
dynamic_emb_module.table_names, dynamic_emb_module.tables
458+
for table_name, table, counter_table in zip(
459+
dynamic_emb_module.table_names,
460+
dynamic_emb_module.tables,
461+
dynamic_emb_module._admission_counter,
410462
):
411463
key_to_score = {}
412464
for batched_key, _, _, batched_score in batched_export_keys_values(
@@ -416,6 +468,21 @@ def test_model_load_dump(
416468
batched_key.tolist(), batched_score.tolist()
417469
):
418470
key_to_score[key] = score
471+
472+
for (
473+
keys,
474+
named_scores,
475+
) in counter_table.table_._batched_export_keys_scores(
476+
counter_table.table_.score_names_, torch.device(f"cpu")
477+
):
478+
if keys.numel() == 0:
479+
continue
480+
freq_name = counter_table.table_.score_names_[0]
481+
frequencies = named_scores[freq_name]
482+
483+
for key, score in zip(keys.tolist(), frequencies.tolist()):
484+
key_to_score[key] = score
485+
419486
table_name_to_key_score_dict[table_name] = key_to_score
420487

421488
for embedding_collection_idx, embedding_idx in product(

corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.sh

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,45 +8,24 @@ NUM_GPUS=(1 4)
88
OPTIMIZER_TYPE=("adam" "sgd" "adagrad" "rowwise_adagrad")
99
INCLUDE_OPTIM=("True" "False")
1010
SCORE_STRATEGY=("timestamp" "lfu" "step")
11+
INCLUDE_COUNTER=("True" "False")
1112

1213
for num_gpus in ${NUM_GPUS[@]}; do
1314
for optimizer_type in ${OPTIMIZER_TYPE[@]}; do
1415
for include_optim in ${INCLUDE_OPTIM[@]}; do
15-
for score_strategy in ${SCORE_STRATEGY[@]}; do
16-
echo "num_gpus: $num_gpus, optimizer_type: $optimizer_type, include_optim: $include_optim, score_strategy: $score_strategy"
17-
torchrun \
18-
--nnodes 1 \
19-
--nproc_per_node $num_gpus \
20-
./test/unit_tests/test_embedding_dump_load.py \
21-
--optimizer-type ${optimizer_type} \
22-
--score-strategy ${score_strategy} \
23-
--mode "dump" \
24-
--optim ${include_optim} \
25-
--save-path "debug_weight_${optimizer_type}_${num_gpus}_${include_optim}_${score_strategy}" \
26-
--num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \
27-
--num-embeddings $NUM_EMBEDDINGS \
28-
--multi-hot-sizes $MULTI_HOT_SIZES \
29-
--embedding-dim 16 || exit 1
30-
done
31-
done
32-
done
33-
done
34-
35-
for num_load_gpus in ${NUM_GPUS[@]}; do
36-
for num_dump_gpus in ${NUM_GPUS[@]}; do
37-
for optimizer_type in ${OPTIMIZER_TYPE[@]}; do
38-
for include_optim in ${INCLUDE_OPTIM[@]}; do
16+
for include_counter in ${INCLUDE_COUNTER[@]}; do
3917
for score_strategy in ${SCORE_STRATEGY[@]}; do
40-
echo "num_load_gpus: $num_load_gpus, num_dump_gpus: $num_dump_gpus, optimizer_type: $optimizer_type, include_optim: $include_optim, score_strategy: $score_strategy"
18+
echo "num_gpus: $num_gpus, optimizer_type: $optimizer_type, include_optim: $include_optim, include_counter: $include_counter, score_strategy: $score_strategy"
4119
torchrun \
4220
--nnodes 1 \
43-
--nproc_per_node $num_load_gpus \
21+
--nproc_per_node $num_gpus \
4422
./test/unit_tests/test_embedding_dump_load.py \
4523
--optimizer-type ${optimizer_type} \
4624
--score-strategy ${score_strategy} \
47-
--mode "load" \
25+
--mode "dump" \
4826
--optim ${include_optim} \
49-
--save-path "debug_weight_${optimizer_type}_${num_dump_gpus}_${include_optim}_${score_strategy}" \
27+
--counter ${include_counter} \
28+
--save-path "debug_weight_${optimizer_type}_${num_gpus}_${include_optim}_${include_counter}_${score_strategy}" \
5029
--num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \
5130
--num-embeddings $NUM_EMBEDDINGS \
5231
--multi-hot-sizes $MULTI_HOT_SIZES \
@@ -55,4 +34,32 @@ for num_load_gpus in ${NUM_GPUS[@]}; do
5534
done
5635
done
5736
done
37+
done
38+
39+
for num_load_gpus in ${NUM_GPUS[@]}; do
40+
for num_dump_gpus in ${NUM_GPUS[@]}; do
41+
for optimizer_type in ${OPTIMIZER_TYPE[@]}; do
42+
for include_optim in ${INCLUDE_OPTIM[@]}; do
43+
for include_counter in ${INCLUDE_COUNTER[@]}; do
44+
for score_strategy in ${SCORE_STRATEGY[@]}; do
45+
echo "num_load_gpus: $num_load_gpus, num_dump_gpus: $num_dump_gpus, optimizer_type: $optimizer_type, include_optim: $include_optim, include_counter: $include_counter, score_strategy: $score_strategy"
46+
torchrun \
47+
--nnodes 1 \
48+
--nproc_per_node $num_load_gpus \
49+
./test/unit_tests/test_embedding_dump_load.py \
50+
--optimizer-type ${optimizer_type} \
51+
--score-strategy ${score_strategy} \
52+
--mode "load" \
53+
--optim ${include_optim} \
54+
--counter ${include_counter} \
55+
--save-path "debug_weight_${optimizer_type}_${num_dump_gpus}_${include_optim}_${include_counter}_${score_strategy}" \
56+
--num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \
57+
--num-embeddings $NUM_EMBEDDINGS \
58+
--multi-hot-sizes $MULTI_HOT_SIZES \
59+
--embedding-dim 16 || exit 1
60+
done
61+
done
62+
done
63+
done
64+
done
5865
done

0 commit comments

Comments (0)