Skip to content

Commit 7c46799

Browse files
committed
doc fixes
1 parent 8594006 commit 7c46799

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

torch_scatter/composite/logsumexp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
2121
dim_size = out.size(dim)
2222
else:
2323
if dim_size is None:
24-
dim_size = int(index.max().item() + 1)
24+
dim_size = int(index.max()) + 1
2525

2626
size = src.size()
2727
size[dim] = dim_size

torch_scatter/scatter.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,23 @@
1010
except OSError:
1111
warnings.warn('Failed to load `scatter` binaries.')
1212

13-
def placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
14-
out: Optional[torch.Tensor],
15-
dim_size: Optional[int]) -> torch.Tensor:
13+
def scatter_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
14+
out: Optional[torch.Tensor],
15+
dim_size: Optional[int]) -> torch.Tensor:
1616
raise ImportError
17+
return src
1718

18-
def arg_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
19-
out: Optional[torch.Tensor], dim_size: Optional[int]
20-
) -> Tuple[torch.Tensor, torch.Tensor]:
19+
def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
20+
dim: int, out: Optional[torch.Tensor],
21+
dim_size: Optional[int]
22+
) -> Tuple[torch.Tensor, torch.Tensor]:
2123
raise ImportError
24+
return src, index
2225

23-
torch.ops.torch_scatter.scatter_sum = placeholder
24-
torch.ops.torch_scatter.scatter_mean = placeholder
25-
torch.ops.torch_scatter.scatter_min = arg_placeholder
26-
torch.ops.torch_scatter.scatter_max = arg_placeholder
26+
torch.ops.torch_scatter.scatter_sum = scatter_placeholder
27+
torch.ops.torch_scatter.scatter_mean = scatter_placeholder
28+
torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
29+
torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
2730

2831

2932
@torch.jit.script

torch_scatter/segment_coo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,19 @@ def segment_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
1414
out: Optional[torch.Tensor],
1515
dim_size: Optional[int]) -> torch.Tensor:
1616
raise ImportError
17+
return src
1718

1819
def segment_coo_with_arg_placeholder(
1920
src: torch.Tensor, index: torch.Tensor,
2021
out: Optional[torch.Tensor],
2122
dim_size: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
2223
raise ImportError
24+
return src, index
2325

2426
def gather_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
2527
out: Optional[torch.Tensor]) -> torch.Tensor:
2628
raise ImportError
29+
return src
2730

2831
torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
2932
torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder

torch_scatter/segment_csr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,18 @@
1313
def segment_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
1414
out: Optional[torch.Tensor]) -> torch.Tensor:
1515
raise ImportError
16+
return src
1617

1718
def segment_csr_with_arg_placeholder(
1819
src: torch.Tensor, indptr: torch.Tensor,
1920
out: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
2021
raise ImportError
22+
return src, indptr
2123

2224
def gather_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
2325
out: Optional[torch.Tensor]) -> torch.Tensor:
2426
raise ImportError
27+
return src
2528

2629
torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
2730
torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder

0 commit comments

Comments
 (0)