Commit 03f9e01

levythu authored and facebook-github-bot committed
Make deep copy when constructing cache params (#3219)
Summary:
Pull Request resolved: #3219

During debugging, we realized that a feature may appear in two archs as different tables. Because sharding constraints are keyed by table name, the same constraint is produced for both tables. If a proposer (e.g. EmbeddingOffloadCacheScaleupProposer) then modifies one table's constraint, the other table's constraint changes with it. This is not expected, since the sharding plan should treat the two tables individually. This diff fixes the issue by deep-copying the nested fields (cache_params and key_value_params) when they are extracted from a constraint.

Reviewed By: iamzainhuda, aliafzal

Differential Revision: D78287465

fbshipit-source-id: 6eab7a47f72e6501eadcd1bc2dbe88f53bb0b984
1 parent 63c914a commit 03f9e01

1 file changed: +4 -2 lines changed

torchrec/distributed/planner/enumerators.py

Lines changed: 4 additions & 2 deletions
@@ -9,6 +9,7 @@

 import copy
 import logging
+from copy import deepcopy
 from typing import Dict, List, Optional, Set, Tuple, Union

 from torch import nn
@@ -364,16 +365,17 @@ def _extract_constraints_for_param(
     key_value_params = None

     if constraints and constraints.get(name):
+        # For nested fields return a deep copy instead
         input_lengths = constraints[name].pooling_factors
         col_wise_shard_dim = constraints[name].min_partition
-        cache_params = constraints[name].cache_params
+        cache_params = deepcopy(constraints[name].cache_params)
         enforce_hbm = constraints[name].enforce_hbm
         stochastic_rounding = constraints[name].stochastic_rounding
         bounds_check_mode = constraints[name].bounds_check_mode
         feature_names = constraints[name].feature_names
         output_dtype = constraints[name].output_dtype
         device_group = constraints[name].device_group
-        key_value_params = constraints[name].key_value_params
+        key_value_params = deepcopy(constraints[name].key_value_params)

     return (
         input_lengths,
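
The aliasing that this change removes can be reproduced in isolation. The sketch below is only an illustration: `ToyCacheParams`, `ToyConstraint`, `extract_shared`, and `extract_copied` are hypothetical stand-ins invented for this example, not the torchrec classes or the real `_extract_constraints_for_param`. It assumes two extractions happen for the same table name and that a proposer-like step then mutates one of the results.

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, Optional


# Hypothetical stand-ins for illustration only (not the torchrec types).
@dataclass
class ToyCacheParams:
    load_factor: Optional[float] = None


@dataclass
class ToyConstraint:
    cache_params: Optional[ToyCacheParams] = None


def make_constraints() -> Dict[str, ToyConstraint]:
    """Build a fresh constraint map keyed by table name."""
    return {"table_0": ToyConstraint(ToyCacheParams(load_factor=0.2))}


def extract_shared(
    constraints: Dict[str, ToyConstraint], name: str
) -> Optional[ToyCacheParams]:
    """Return the nested object itself, so every caller shares one instance."""
    return constraints[name].cache_params


def extract_copied(
    constraints: Dict[str, ToyConstraint], name: str
) -> Optional[ToyCacheParams]:
    """Return an independent deep copy of the nested object."""
    return deepcopy(constraints[name].cache_params)


# Without the copy: both extractions alias the same object, so mutating
# one result silently changes the other and the constraint itself.
constraints = make_constraints()
shared_a = extract_shared(constraints, "table_0")
shared_b = extract_shared(constraints, "table_0")
shared_a.load_factor = 0.8
shared_leaked = shared_b.load_factor == 0.8

# With the copy: each extraction is independent, so the mutation stays local.
constraints_fixed = make_constraints()
copied_a = extract_copied(constraints_fixed, "table_0")
copied_b = extract_copied(constraints_fixed, "table_0")
copied_a.load_factor = 0.8
copy_isolated = copied_b.load_factor == 0.2
```

`deepcopy` rather than a shallow `copy.copy` is the conservative choice when the params object may itself hold nested mutable fields, which is what the added comment in the diff ("For nested fields return a deep copy instead") suggests.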
