@@ -368,14 +368,26 @@ def __init__(
368
368
torch .broadcast_shapes (coefficient .shape , offset .shape )
369
369
370
370
self ._d = d
371
- self .register_buffer ("coefficient " , coefficient )
372
- self .register_buffer ("offset " , offset )
371
+ self .register_buffer ("_coefficient " , coefficient )
372
+ self .register_buffer ("_offset " , offset )
373
373
self .batch_shape = batch_shape
374
374
self .transform_on_train = transform_on_train
375
375
self .transform_on_eval = transform_on_eval
376
376
self .transform_on_fantasize = transform_on_fantasize
377
377
self .reverse = reverse
378
378
379
+ @property
380
+ def coefficient (self ) -> Tensor :
381
+ r"""The tensor of linear coefficients."""
382
+ coeff = self ._coefficient
383
+ return coeff if self .learn_coefficients and self .training else coeff .detach ()
384
+
385
+ @property
386
+ def offset (self ) -> Tensor :
387
+ r"""The tensor of offset coefficients."""
388
+ offset = self ._offset
389
+ return offset if self .learn_coefficients and self .training else offset .detach ()
390
+
379
391
@property
380
392
def learn_coefficients (self ) -> bool :
381
393
return getattr (self , "_learn_coefficients" , False )
@@ -459,8 +471,8 @@ def _check_shape(self, X: Tensor) -> None:
459
471
460
472
def _to (self , X : Tensor ) -> None :
461
473
r"""Makes coefficient and offset have same device and dtype as X."""
462
- self .coefficient = self .coefficient .to (X )
463
- self .offset = self .offset .to (X )
474
+ self ._coefficient = self .coefficient .to (X )
475
+ self ._offset = self .offset .to (X )
464
476
465
477
def _update_coefficients (self , X : Tensor ) -> None :
466
478
r"""Updates affine coefficients. Implemented by subclasses,
@@ -569,9 +581,9 @@ def _update_coefficients(self, X) -> None:
569
581
# Aggregate mins and ranges over extra batch and marginal dims
570
582
batch_ndim = min (len (self .batch_shape ), X .ndim - 2 ) # batch rank of `X`
571
583
reduce_dims = (* range (X .ndim - batch_ndim - 2 ), X .ndim - 2 )
572
- self .offset = torch .amin (X , dim = reduce_dims ).unsqueeze (- 2 )
573
- self .coefficient = torch .amax (X , dim = reduce_dims ).unsqueeze (- 2 ) - self .offset
574
- self .coefficient .clamp_ (min = self .min_range )
584
+ self ._offset = torch .amin (X , dim = reduce_dims ).unsqueeze (- 2 )
585
+ self ._coefficient = torch .amax (X , dim = reduce_dims ).unsqueeze (- 2 ) - self .offset
586
+ self ._coefficient .clamp_ (min = self .min_range )
575
587
576
588
577
589
class InputStandardize (AffineInputTransform ):
@@ -641,11 +653,11 @@ def _update_coefficients(self, X: Tensor) -> None:
641
653
# Aggregate means and standard deviations over extra batch and marginal dims
642
654
batch_ndim = min (len (self .batch_shape ), X .ndim - 2 ) # batch rank of `X`
643
655
reduce_dims = (* range (X .ndim - batch_ndim - 2 ), X .ndim - 2 )
644
- coefficient , self .offset = (
656
+ coefficient , self ._offset = (
645
657
values .unsqueeze (- 2 )
646
658
for values in torch .std_mean (X , dim = reduce_dims , unbiased = True )
647
659
)
648
- self .coefficient = coefficient .clamp_ (min = self .min_std )
660
+ self ._coefficient = coefficient .clamp_ (min = self .min_std )
649
661
650
662
651
663
class Round (InputTransform , Module ):
0 commit comments