Skip to content

Commit cc84214

Browse files
Googlertensorflower-gardener
authored andcommitted
Correctly expose sub-distribution parameters in distributions created by
inflated_factory. PiperOrigin-RevId: 471879189
1 parent 1999fbe commit cc84214

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tensorflow_probability/python/distributions/inflated.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,16 @@ def inflated_factory(default_name, distribution_class, inflated_loc,
224224
def my_init(self,
225225
inflated_loc_logits=None, inflated_loc_probs=None,
226226
name=default_name, **kwargs):
227+
parameters = dict(locals())
227228
if 'distribution' in kwargs:
228229
dist = kwargs['distribution']
229230
else:
230231
dist = distribution_class(**{**kwargs, **more_kwargs})
231232
Inflated.__init__(self, dist, inflated_loc_logits, inflated_loc_probs,
232233
inflated_loc, name=name)
234+
# pylint: disable=protected-access
235+
self._parameters = {**parameters, **more_kwargs}
236+
# pylint: enable=protected-access
233237

234238
def my_parameter_properties(unused_cls, dtype, num_classes=None):
235239
return dict(

0 commit comments

Comments
 (0)