Skip to content

Commit 91219d6

Browse files
author
Johannes Ballé
committed
Sets __all__ for entropy models and range coding ops.
1 parent 5e98dfd commit 91219d6

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

python/layers/entropy_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@
3030
from tensorflow_compression.python.ops import range_coding_ops
3131

3232

33+
__all__ = [
34+
"EntropyModel",
35+
"EntropyBottleneck",
36+
"SymmetricConditional",
37+
"GaussianConditional",
38+
"LogisticConditional",
39+
"LaplacianConditional",
40+
]
41+
42+
3343
class EntropyModel(tf.keras.layers.Layer):
3444
"""Entropy model (base class).
3545

python/ops/range_coding_ops.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,16 @@
2424
from tensorflow.python.framework import load_library
2525
from tensorflow.python.platform import resource_loader
2626

27+
28+
__all__ = list()
29+
_ops = dict()
2730
_range_coding_ops = load_library.load_op_library(
2831
resource_loader.get_path_to_datafile("../../_range_coding_ops.so"))
29-
30-
pmf_to_quantized_cdf = _range_coding_ops.pmf_to_quantized_cdf
31-
range_decode = _range_coding_ops.range_decode
32-
range_encode = _range_coding_ops.range_encode
32+
for name in dir(_range_coding_ops):
33+
if name.startswith("_"):
34+
continue
35+
if name in ("LIB_HANDLE", "OP_LIST", "deprecated_endpoints", "tf_export"):
36+
continue
37+
__all__.append(name)
38+
_ops[name] = getattr(_range_coding_ops, name)
39+
globals().update(_ops)

0 commit comments

Comments
 (0)