Skip to content

Commit 7f3aa92

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Fix pickle error in TorchPosterior (meta-pytorch#1644)
Summary: ## Motivation Fixes meta-pytorch#1639. The issue was that pickle uses `__getstate__` and `__setstate__` to save / load the objects. In `TorchPosterior`, `__getattr__` was taking precedence, so instead of calling `TorchPosterior.__set/getstate__` it was calling `self.distribution.__set/getstate__`. While loading the posterior, we start with a `TorchPosterior` that does not have any attributes --including `distribution`--, so calling `self.distribution` was leading to an infinite loop of calling `TorchPosterior` with `name=distribution` (since it calls `getattr(self.distribution, name)`). Pull Request resolved: meta-pytorch#1644 Test Plan: Units. Reviewed By: Balandat Differential Revision: D42732823 Pulled By: saitcakmak fbshipit-source-id: 20183a7d5368a1e92dc36cd9726d0caf33f791d5
1 parent 67c4f40 commit 7f3aa92

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

botorch/posteriors/torch.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
from typing import Any, Optional
13+
from typing import Any, Dict, Optional
1414

1515
import torch
1616
from botorch.posteriors.posterior import Posterior
@@ -78,6 +78,20 @@ def __getattr__(self, name: str) -> Any:
7878
"""
7979
return getattr(self.distribution, name)
8080

81+
def __getstate__(self) -> Dict[str, Any]:
82+
r"""A minimal utility to support pickle protocol.
83+
84+
Pickle uses `__get/setstate__` to serialize / deserialize the objects.
85+
Since we define `__getattr__` above, it takes precedence over these
86+
methods, and we end up in an infinite loop unless we also define
87+
`__getstate__` and `__setstate__`.
88+
"""
89+
return self.__dict__
90+
91+
def __setstate__(self, d: Dict[str, Any]) -> None:
92+
r"""A minimal utility to support pickle protocol."""
93+
self.__dict__ = d
94+
8195
def quantile(self, value: Tensor) -> Tensor:
8296
r"""Compute quantiles of the distribution.
8397

test/posteriors/test_torch_posterior.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import os
8+
import tempfile
9+
import unittest
710

811
import torch
912
from botorch.posteriors.torch import TorchPosterior
@@ -55,3 +58,20 @@ def test_torch_posterior(self):
5558
[posterior.distribution.log_prob(q).exp() for q in q_value], dim=0
5659
)
5760
self.assertAllClose(posterior.density(q_value), expected)
61+
62+
@unittest.skipIf(os.name == "nt", "Pickle test is not supported on Windows.")
63+
def test_pickle(self) -> None:
64+
for dtype in (torch.float, torch.double):
65+
tkwargs = {"dtype": dtype, "device": self.device}
66+
posterior = TorchPosterior(Exponential(rate=torch.rand(1, 2, **tkwargs)))
67+
with tempfile.NamedTemporaryFile() as tmp_file:
68+
torch.save(posterior, tmp_file.name)
69+
loaded_posterior = torch.load(tmp_file.name)
70+
self.assertEqual(posterior.dtype, loaded_posterior.dtype)
71+
self.assertEqual(posterior.device, loaded_posterior.device)
72+
self.assertTrue(
73+
torch.equal(
74+
posterior.distribution.rate,
75+
loaded_posterior.distribution.rate,
76+
)
77+
)

0 commit comments

Comments
 (0)