Skip to content

Commit 969916d

Browse files
ds-hwangcopybara-github
authored andcommitted
Cache GetAllDatasetParams.
Every GetDatasetParams(self, dataset) call calls GetAllDatasetParams(), which instantiates all dataset and tasks. In MMASR+intrainer case, there are 80 dataset and corresponding Task_XXX, which are 2B params model. It takes more than 10 mins. The initial GetAllDatasetParams() API is unfortunate. It supposed to be a generator or keep only factory methods, instead of actual instantiation. PiperOrigin-RevId: 492103447
1 parent 794110e commit 969916d

File tree

4 files changed

+43
-28
lines changed

4 files changed

+43
-28
lines changed

lingvo/core/base_model_params.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""BaseModelParams class definition."""
16-
16+
import functools
1717

1818
from lingvo import datasets
1919
from lingvo.core import base_input_generator
@@ -35,6 +35,13 @@ class DatasetError(Exception):
3535
pass
3636

3737

38+
# Cache the GetAllDatasetParams() results, which may include many Task_XXX.
39+
# Note: datasets inspect _BaseModelParams, so it can not have a cached method.
40+
@functools.lru_cache(maxsize=8)
41+
def GetCachedAllDatasetParams(obj):
42+
return obj.GetAllDatasetParams()
43+
44+
3845
class _BaseModelParams:
3946
"""Base class for storing model Params for a single experiment."""
4047

@@ -45,12 +52,10 @@ def GetAllDatasetParams(self):
4552
be treated as dataset specifications.
4653
4754
Returns:
48-
A dict of {dataset_name: dataset_params}.
49-
50-
Raises:
51-
GetAllDatasetParamsNotImplementedError: by default.
55+
A dict of {dataset_name: dataset_params}. If not implemented, returns
56+
None.
5257
"""
53-
raise datasets.GetAllDatasetParamsNotImplementedError(type(self))
58+
return None
5459

5560
def GetDatasetParams(self, dataset):
5661
"""Convenience function that returns the param for the given dataset name.
@@ -65,16 +70,13 @@ def GetDatasetParams(self, dataset):
6570
Raises:
6671
DatasetError: if there is not a `${dataset}` method defined under `cls`.
6772
"""
68-
try:
69-
all_datasets = self.GetAllDatasetParams()
70-
if dataset not in all_datasets:
73+
all_datasets = GetCachedAllDatasetParams(self)
74+
if all_datasets is not None:
75+
if dataset not in all_datasets.keys():
7176
# When `GetAllDatasetParams` is defined, all public methods are ignored.
7277
raise DatasetError(f'Dataset {dataset} not found; '
7378
f'available datasets are: {all_datasets.keys()}')
7479
return all_datasets.get(dataset)
75-
except datasets.GetAllDatasetParamsNotImplementedError:
76-
# Fall through the legacy path.
77-
pass
7880

7981
try:
8082
f = getattr(self, dataset)

lingvo/datasets.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ class DatasetFunctionError(TypeError):
3131
pass
3232

3333

34-
class GetAllDatasetParamsNotImplementedError(NotImplementedError):
35-
pass
36-
37-
3834
def GetDatasets(cls: Any, warn_on_error: bool = True) -> List[str]:
3935
"""Returns the list of dataset functions (e.g., Train, Dev, ...).
4036
@@ -55,7 +51,6 @@ def GetDatasets(cls: Any, warn_on_error: bool = True) -> List[str]:
5551
DatasetFunctionError: if the cls contains public methods that cannot be used
5652
as datasets, and warn_on_error is False.
5753
"""
58-
5954
mdl_params = None
6055
if inspect.isclass(cls):
6156
try:
@@ -65,13 +60,11 @@ def GetDatasets(cls: Any, warn_on_error: bool = True) -> List[str]:
6560
else:
6661
mdl_params = cls
6762

68-
if mdl_params:
69-
try:
70-
all_datasets = mdl_params.GetAllDatasetParams()
63+
if mdl_params and hasattr(mdl_params, 'GetAllDatasetParams'):
64+
all_datasets = mdl_params.GetAllDatasetParams()
65+
if all_datasets is not None:
7166
# When `GetAllDatasetParams` is defined, all public methods are ignored.
7267
return sorted(list(all_datasets.keys()))
73-
except GetAllDatasetParamsNotImplementedError:
74-
pass
7568

7669
datasets = []
7770
for name, _ in inspect.getmembers(cls, inspect.isroutine):

lingvo/model_registry.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,22 @@ def GetClass(cls, class_key):
243243
class_key)
244244
return all_params[class_key]
245245

246+
@classmethod
247+
@functools.lru_cache(maxsize=8)
248+
def GetCachedObject(cls, class_key):
249+
"""Returns a ModelParams class object with the given `class_key`.
250+
251+
To reuse the same object for given `class_key`, it uses
252+
functools.lru_cache annotation.
253+
254+
Args:
255+
class_key: string key of the ModelParams subclass to return.
256+
257+
Returns:
258+
A instance of `~.base_model_params._BaseModelParams`.
259+
"""
260+
return cls.GetClass(class_key)()
261+
246262
@classmethod
247263
def GetParamsFromModelParamsObject(cls, model_params_obj, dataset_name):
248264
"""Returns a `Params` object for given _BaseModelParams instance.
@@ -281,7 +297,7 @@ def GetParamsFromModelParamsObject(cls, model_params_obj, dataset_name):
281297
return cfg
282298

283299
@classmethod
284-
def GetParams(cls, class_key, dataset_name):
300+
def GetParams(cls, class_key, dataset_name, cache=True):
285301
"""Constructs a `Params` object for given model and dataset, obeying flags.
286302
287303
In case of default model, params may be updated based on the flags
@@ -290,12 +306,17 @@ def GetParams(cls, class_key, dataset_name):
290306
Args:
291307
class_key: String class key (i.e. `image.mnist.LeNet5`).
292308
dataset_name: Method to generate dataset params (i.e. 'Test').
309+
cache: Whether to cache the params given class_key and dataset_name.
293310
294311
Returns:
295312
Full `~.hyperparams.Params` for the model class.
296313
"""
297-
model_params_cls = cls.GetClass(class_key)
298-
return cls.GetParamsFromModelParamsObject(model_params_cls(), dataset_name)
314+
if cache:
315+
# Reuse the cached object to reused the cached obj.GetAllDatasetParams().
316+
model_params_obj = cls.GetCachedObject(class_key)
317+
else:
318+
model_params_obj = cls.GetClass(class_key)()
319+
return cls.GetParamsFromModelParamsObject(model_params_obj, dataset_name)
299320

300321
@classmethod
301322
def GetProgramScheduleFromModelParamsObject(cls, model_params_obj):

lingvo/models_test_helper.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,8 @@ def _ValidateEMA(self, name, mdl):
129129
def _testOneModelParams(self, registry, name):
130130
with tf.Graph().as_default():
131131
model_params = registry.GetClass(name)()
132-
try:
133-
all_datasets = model_params.GetAllDatasetParams()
134-
except datasets.GetAllDatasetParamsNotImplementedError:
132+
all_datasets = model_params.GetAllDatasetParams()
133+
if all_datasets is None:
135134
all_datasets = {}
136135
for dataset_name in datasets.GetDatasets(model_params):
137136
try:

0 commit comments

Comments
 (0)