Skip to content

Commit 5926cdb

Browse files
markmcdcopybara-github
authored andcommitted
Add support for using the module dir as base_dir.
Currently, if we try to document `/path/module/__init__.py` using `base_dir=['/path/module']`, everything under `/path/module` is trimmed for being outside of the base dir. This is in prep for adding more modules to the TF Lite Support library, where submodules come from disparate parts of the repo. PiperOrigin-RevId: 446334715
1 parent 5ae1433 commit 5926cdb

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

tools/tensorflow_docs/api_generator/public_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,9 @@ def __call__(self, path: Sequence[str], parent: Any,
331331
# contents will get filtered when the submodules are checked.
332332
if len(mod_base_dirs) == 1:
333333
mod_base_dir = mod_base_dirs[0]
334-
# Check that module is in one of the `self._base_dir`s
335-
if not any(base in mod_base_dir.parents for base in self.base_dirs):
334+
# Check that module is, or is in one of the `self._base_dir`s
335+
if not (any(base in mod_base_dir.parents for base in self.base_dirs) or
336+
mod_base_dir in self.base_dirs):
336337
continue
337338
yield name, child
338339

tools/tensorflow_docs/api_generator/public_api_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,21 @@ def test_filter_base_dirs(self):
210210
('sub2', module.sub2)])
211211
self.assertEqual([('a', module.a), ('sub1', module.sub1)], list(result))
212212

213+
def test_filter_base_dir_pointing_to_submodule_dir(self):
214+
module = types.ModuleType('module')
215+
module.__file__ = '/1/2/3/module'
216+
module.submodule = types.ModuleType('submodule')
217+
module.submodule.__file__ = '/1/2/3/submodule/__init__.py'
218+
219+
test_filter = public_api.FilterBaseDirs(
220+
base_dirs=[pathlib.Path('/1/2/3/submodule')])
221+
result = test_filter(
222+
path=('module',),
223+
parent=module,
224+
children=[('submodule', module.submodule)])
225+
226+
self.assertEqual([('submodule', module.submodule)], list(result))
227+
213228

214229
if __name__ == '__main__':
215230
absltest.main()

0 commit comments

Comments
 (0)