44from .utils .unique import unique
55
66
7- def coalesce (index , value , m , n , op = 'add' , fill_value = 0 ):
7+ def coalesce (index , value , m , n , op = 'add' ):
88 """Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
99 entries are removed by scattering them together. For scattering, any
1010 operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_
@@ -17,8 +17,6 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
1717 n (int): The second dimension of corresponding dense matrix.
1818 op (string, optional): The scatter operation to use. (default:
1919 :obj:`"add"`)
20- fill_value (int, optional): The initial fill value of scatter
21- operation. (default: :obj:`0`)
2220
2321 :rtype: (:class:`LongTensor`, :class:`Tensor`)
2422 """
@@ -37,8 +35,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
3735 index = torch .stack ([row [perm ], col [perm ]], dim = 0 )
3836
3937 op = getattr (torch_scatter , 'scatter_{}' .format (op ))
40- value = op (value , inv , 0 , None , perm .size (0 ), fill_value )
41- if isinstance (value , tuple ):
42- value = value [0 ]
38+ value = op (value , inv , 0 , None , perm .size (0 ))
39+ value = value [0 ] if isinstance (value , tuple ) else value
4340
4441 return index , value
0 commit comments