
Commit de31154

saitcakmak authored and facebook-github-bot committed
Detach coefficient and offset in AffineTransform in eval mode (meta-pytorch#1642)
Summary:
Pull Request resolved: meta-pytorch#1642

See meta-pytorch#1635. The coefficient and offset should only retain grad while the transform is learning the coefficients and is in train mode.

Reviewed By: SebastianAment

Differential Revision: D42700421

fbshipit-source-id: 3853071f256b268ef46a239a91b13486199d6cac
1 parent 72e872a commit de31154
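
The change routes every read of coefficient/offset through a property that detaches the underlying buffer unless the module is both learning the coefficients and in train mode. A minimal self-contained sketch of that pattern (toy module and names of my own, not the actual BoTorch class):

import torch
from torch import Tensor, nn


class DetachingModule(nn.Module):
    """Toy module mirroring the commit's pattern: the raw value lives in a
    privately named buffer; a property detaches it outside train mode."""

    def __init__(self, coefficient: Tensor, learn_coefficients: bool = False) -> None:
        super().__init__()
        self.register_buffer("_coefficient", coefficient)
        self.learn_coefficients = learn_coefficients

    @property
    def coefficient(self) -> Tensor:
        coeff = self._coefficient
        # Keep the autograd graph only while learning the coefficients
        # in train mode; otherwise return a detached view.
        return coeff if self.learn_coefficients and self.training else coeff.detach()


# Give the buffer a grad_fn by deriving it from a tensor that requires grad.
base = torch.tensor([1.0, 2.0], requires_grad=True)
module = DetachingModule(base * 2.0, learn_coefficients=True)
print(module.coefficient.grad_fn)  # a grad_fn: graph retained in train mode
module.eval()
print(module.coefficient.grad_fn)  # None: detached in eval mode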

File tree: 2 files changed, +47 −9 lines

botorch/models/transforms/input.py
Lines changed: 21 additions & 9 deletions

@@ -368,14 +368,26 @@ def __init__(
         torch.broadcast_shapes(coefficient.shape, offset.shape)
 
         self._d = d
-        self.register_buffer("coefficient", coefficient)
-        self.register_buffer("offset", offset)
+        self.register_buffer("_coefficient", coefficient)
+        self.register_buffer("_offset", offset)
         self.batch_shape = batch_shape
         self.transform_on_train = transform_on_train
         self.transform_on_eval = transform_on_eval
         self.transform_on_fantasize = transform_on_fantasize
         self.reverse = reverse
 
+    @property
+    def coefficient(self) -> Tensor:
+        r"""The tensor of linear coefficients."""
+        coeff = self._coefficient
+        return coeff if self.learn_coefficients and self.training else coeff.detach()
+
+    @property
+    def offset(self) -> Tensor:
+        r"""The tensor of offset coefficients."""
+        offset = self._offset
+        return offset if self.learn_coefficients and self.training else offset.detach()
+
     @property
     def learn_coefficients(self) -> bool:
         return getattr(self, "_learn_coefficients", False)

@@ -459,8 +471,8 @@ def _check_shape(self, X: Tensor) -> None:
 
     def _to(self, X: Tensor) -> None:
         r"""Makes coefficient and offset have same device and dtype as X."""
-        self.coefficient = self.coefficient.to(X)
-        self.offset = self.offset.to(X)
+        self._coefficient = self.coefficient.to(X)
+        self._offset = self.offset.to(X)
 
     def _update_coefficients(self, X: Tensor) -> None:
         r"""Updates affine coefficients. Implemented by subclasses,

@@ -569,9 +581,9 @@ def _update_coefficients(self, X) -> None:
         # Aggregate mins and ranges over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
         reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
-        self.offset = torch.amin(X, dim=reduce_dims).unsqueeze(-2)
-        self.coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - self.offset
-        self.coefficient.clamp_(min=self.min_range)
+        self._offset = torch.amin(X, dim=reduce_dims).unsqueeze(-2)
+        self._coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - self.offset
+        self._coefficient.clamp_(min=self.min_range)
 
 
 class InputStandardize(AffineInputTransform):

@@ -641,11 +653,11 @@ def _update_coefficients(self, X: Tensor) -> None:
         # Aggregate means and standard deviations over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
         reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
-        coefficient, self.offset = (
+        coefficient, self._offset = (
             values.unsqueeze(-2)
             for values in torch.std_mean(X, dim=reduce_dims, unbiased=True)
         )
-        self.coefficient = coefficient.clamp_(min=self.min_std)
+        self._coefficient = coefficient.clamp_(min=self.min_std)
 
 
 class Round(InputTransform, Module):
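
For reference, a quick sketch of the observable behavior this enables, assuming a BoTorch build that includes this commit (it mirrors the tests below):

import torch
from botorch.models.transforms.input import Normalize

X = torch.rand(10, 2) * torch.tensor(2.0, requires_grad=True)
nlz = Normalize(d=2)  # no bounds given, so coefficients are learned from X
nlz(X)  # train mode: coefficient/offset keep their autograd graph
assert nlz.coefficient.grad_fn is not None
nlz.eval()  # eval mode: the properties return detached tensors
assert nlz.coefficient.grad_fn is None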

test/models/transforms/test_input.py
Lines changed: 26 additions & 0 deletions

@@ -165,6 +165,20 @@ def test_normalize(self):
             self.assertTrue(
                 torch.equal(nlz.mins, bounds[..., 1:2, :] - bounds[..., 0:1, :])
             )
+            # with grad
+            bounds.requires_grad = True
+            bounds = bounds * 2
+            self.assertIsNotNone(bounds.grad_fn)
+            nlz = Normalize(d=2, bounds=bounds)
+            # Set learn_coefficients=True for testing.
+            nlz.learn_coefficients = True
+            # We have grad in train mode.
+            self.assertIsNotNone(nlz.coefficient.grad_fn)
+            self.assertIsNotNone(nlz.offset.grad_fn)
+            # Grad is detached in eval mode.
+            nlz.eval()
+            self.assertIsNone(nlz.coefficient.grad_fn)
+            self.assertIsNone(nlz.offset.grad_fn)
 
             # basic init, provided indices
             with self.assertRaises(ValueError):

@@ -326,6 +340,18 @@ def test_normalize(self):
             nlz10 = Normalize(d=3, batch_shape=batch_shape, indices=[0, 2])
             self.assertFalse(nlz9.equals(nlz10))
 
+            # test with grad
+            nlz = Normalize(d=1)
+            X.requires_grad = True
+            X = X * 2
+            self.assertIsNotNone(X.grad_fn)
+            nlz(X)
+            self.assertIsNotNone(nlz.coefficient.grad_fn)
+            self.assertIsNotNone(nlz.offset.grad_fn)
+            nlz.eval()
+            self.assertIsNone(nlz.coefficient.grad_fn)
+            self.assertIsNone(nlz.offset.grad_fn)
+
     def test_standardize(self):
         for dtype in (torch.float, torch.double):
             # basic init
