Skip to content

Commit 0f8caa2

Browse files
ds-hwang authored and copybara-github committed
Add tests for the case where a model's datasets are defined by both GetAllDatasetParams and public methods.
It's possible that some datasets are defined by GetAllDatasetParams while others are defined by public methods. In this case, the current datasets code ignores the public methods, while _BaseModelParams takes care of this case. This is the right behavior. Add comments and unit tests to show it explicitly. PiperOrigin-RevId: 491466065
1 parent 0e43ced commit 0f8caa2

File tree

4 files changed

+32
-0
lines changed

4 files changed

+32
-0
lines changed

lingvo/core/base_model_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def GetDatasetParams(self, dataset):
6868
try:
6969
all_datasets = self.GetAllDatasetParams()
7070
if dataset not in all_datasets:
71+
# When `GetAllDatasetParams` is defined, all public methods are ignored.
7172
raise DatasetError(f'Dataset {dataset} not found; '
7273
f'available datasets are: {all_datasets.keys()}')
7374
return all_datasets.get(dataset)

lingvo/core/base_model_params_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ def GetAllDatasetParams(self):
2424
raise NotImplementedError('test error')
2525

2626

27+
class PartialDatasetParams(base_model_params.SingleTaskModelParams):
28+
29+
def GetAllDatasetParams(self):
30+
return dict(Test_X1=self.Test())
31+
32+
def Test_X2(self):
33+
pass
34+
35+
2736
class BaseModelParamsTest(test_utils.TestCase):
2837

2938
def testGetDatasetParams_SingleTaskModelParams(self):
@@ -47,6 +56,13 @@ def testGetDatasetParams_NotImplementedError(self):
4756
with self.assertRaisesRegexp(NotImplementedError, 'test error'):
4857
dummy_model.GetDatasetParams('Train')
4958

59+
def testGetDatasetParams_PartialDatasetParams(self):
60+
dummy_model = PartialDatasetParams()
61+
self.assertEqual(dummy_model.Test(),
62+
dummy_model.GetDatasetParams('Test_X1'))
63+
with self.assertRaises(base_model_params.DatasetError):
64+
dummy_model.GetDatasetParams('Test_X2')
65+
5066

5167
if __name__ == '__main__':
5268
test_utils.main()

lingvo/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def GetDatasets(cls: Any, warn_on_error: bool = True) -> List[str]:
6868
if mdl_params:
6969
try:
7070
all_datasets = mdl_params.GetAllDatasetParams()
71+
# When `GetAllDatasetParams` is defined, all public methods are ignored.
7172
return sorted(list(all_datasets.keys()))
7273
except GetAllDatasetParamsNotImplementedError:
7374
pass

lingvo/datasets_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ def GetAllDatasetParams(self):
121121
self.assertAllEqual(['Dev', 'Train'],
122122
datasets.GetDatasets(DummyDatasetHolder()))
123123

124+
def testGetDatasetsWithGetAllDatasetParamsAndPublicMethods(self):
125+
126+
class DummyDatasetHolder(base_model_params._BaseModelParams):
127+
128+
def UnexpectedDatasetName1(self):
129+
pass
130+
131+
def GetAllDatasetParams(self):
132+
return {'UnexpectedDatasetName2': None}
133+
134+
found_datasets = datasets.GetDatasets(DummyDatasetHolder)
135+
136+
self.assertAllEqual(['UnexpectedDatasetName2'], found_datasets)
137+
124138
def testGetDatasetsOnClassWithPositionalArgumentInit(self):
125139

126140
class DummyDatasetHolder(base_model_params._BaseModelParams):

0 commit comments

Comments
 (0)