Skip to content

Commit d63eb9c

Browse files
committed
Remaining flake8 formatting errors
1 parent 0c12788 commit d63eb9c

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

test/test_logsumexp.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,20 @@
99
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
1010

1111

12-
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
12+
@pytest.mark.parametrize('dtype,device',
13+
product(SUPPORTED_FLOAT_DTYPES, devices))
1314
def test_logsumexp(dtype, device):
1415
src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
1516
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
1617

1718
out = scatter_logsumexp(src, index)
1819

19-
idx0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1).tolist()
20-
idx1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
20+
idx0 = torch.logsumexp(
21+
torch.tensor([0.5, 0.5], dtype=dtype),
22+
dim=-1).tolist()
23+
idx1 = torch.logsumexp(
24+
torch.tensor([0, -2.1, 3.2], dtype=dtype),
25+
dim=-1).tolist()
2126
idx2 = 7 # Single element
2227
idx3 = torch.finfo(dtype).min # Empty index, returns yield value
2328
idx4 = -1 # logsumexp with -inf is the identity

test/test_softmax.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,20 @@
1010
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
1111

1212

13-
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
13+
@pytest.mark.parametrize('dtype,device',
14+
product(SUPPORTED_FLOAT_DTYPES, devices))
1415
def test_log_softmax(dtype, device):
15-
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
16+
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
17+
dtype, device)
1618
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
1719

1820
out = scatter_log_softmax(src, index)
1921

2022
# Expected results per index
2123
idx0 = [np.log(0.5), np.log(0.5)]
22-
idx1 = torch.log_softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
24+
idx1 = torch.log_softmax(
25+
torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
26+
dim=-1).tolist()
2327
idx2 = 0.0 # Single element, has logprob=0
2428
# index=3 is empty. Should not matter.
2529
idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf
@@ -31,16 +35,20 @@ def test_log_softmax(dtype, device):
3135
)
3236

3337

34-
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
38+
@pytest.mark.parametrize('dtype,device',
39+
product(SUPPORTED_FLOAT_DTYPES, devices))
3540
def test_softmax(dtype, device):
36-
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
41+
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
42+
dtype, device)
3743
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
3844

3945
out = scatter_softmax(src, index)
4046

4147
# Expected results per index
4248
idx0 = [0.5, 0.5]
43-
idx1 = torch.softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
49+
idx1 = torch.softmax(
50+
torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
51+
dim=-1).tolist()
4452
idx2 = 1 # Single element, has prob=1
4553
# index=3 is empty. Should not matter.
4654
idx4 = [1.0, 0.0] # softmax with -inf yields zero probability

0 commit comments

Comments (0)