Skip to content

Commit 288aa87

Browse files
blzhengpytorchmergebot
authored andcommitted
[Inductor][CPU] disable bernoulli_p decomposition (pytorch#143460)
Fix pytorch#142853 `fallback_random=True` should cause RNG to match between compile/eager (by having compile fall back to eager for RNG ops), but the `bernoulli_p` decompose function is not fully consistent with the eager CPU implementation. We remove the decomp and keep the version for` fallback_random=False`. Pull Request resolved: pytorch#143460 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel
1 parent fd8b217 commit 288aa87

File tree

4 files changed

+22
-26
lines changed

4 files changed

+22
-26
lines changed

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,7 @@ aten::bernoulli.Tensor
704704
aten::bernoulli.Tensor_out
705705
aten::bernoulli.float_out
706706
aten::bernoulli.out
707+
aten::bernoulli.p
707708
aten::bernoulli_.Tensor
708709
aten::bernoulli_.float
709710
aten::bincount

test/inductor/test_cpu_repro.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4943,6 +4943,27 @@ def forward(self, context_layer, hidden_states):
49434943
torch.compile(converted_model)(*example_batch)
49444944
check_metrics_vec_kernel_count(3)
49454945

4946+
def test_dropout(self):
4947+
class Model(nn.Module):
4948+
def __init__(self, dim):
4949+
super().__init__()
4950+
self.dropout = eval(f"nn.Dropout{dim}d(p=0.5)")
4951+
4952+
def forward(self, x):
4953+
torch.manual_seed(0)
4954+
x = self.dropout(x)
4955+
return x
4956+
4957+
for dim in [1, 2, 3]:
4958+
model = Model(dim)
4959+
torch.manual_seed(0)
4960+
shape = [1, 3] + [256] * dim
4961+
x = torch.randn(*shape)
4962+
output = model(x)
4963+
c_model = torch.compile(model)
4964+
c_output = c_model(x)
4965+
self.assertTrue(torch.allclose(output, c_output))
4966+
49464967

49474968
if __name__ == "__main__":
49484969
from torch._inductor.test_case import run_tests

test/test_decomp.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -608,17 +608,6 @@ def test_uniform(self, device):
608608
res = torch._decomp.decompositions.uniform(x, low=low, high=high)
609609
self.assertEqual(ref, res)
610610

611-
def test_bernoulli_p(self, device):
612-
p = 0.3
613-
input_t = torch.rand(100, 100)
614-
torch.manual_seed(123)
615-
ref = torch.ops.aten.bernoulli.p(input_t, p)
616-
torch.manual_seed(123)
617-
res = torch._decomp.decompositions.bernoulli_p(input_t, p)
618-
ref_p = ref.sum() / torch.prod(torch.tensor(ref.size()))
619-
res_p = res.sum() / torch.prod(torch.tensor(res.size()))
620-
self.assertEqual(ref_p, res_p, atol=0.06 * p, rtol=0.06)
621-
622611
def test_bernoulli_default(self, device):
623612
p = 0.3
624613
p_t = p * torch.ones(5, 5)

torch/_decomp/decompositions.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5117,21 +5117,6 @@ def bernoulli(
51175117
return p
51185118

51195119

5120-
@register_decomposition(aten.bernoulli.p)
5121-
def bernoulli_p(self, p, *, generator: Optional[torch.Generator] = None):
5122-
if generator is None:
5123-
raw_p = torch.rand(self.size(), dtype=torch.float32, device=self.device)
5124-
else:
5125-
raw_p = torch.rand(
5126-
self.size(),
5127-
generator=generator,
5128-
dtype=self.float32,
5129-
device=self.device,
5130-
)
5131-
p = (raw_p < p).to(self.dtype)
5132-
return p
5133-
5134-
51355120
def isin_default(elements, test_elements, *, invert=False):
51365121
if elements.numel() == 0:
51375122
return torch.empty_like(elements, dtype=torch.bool)

0 commit comments

Comments
 (0)