Skip to content

Commit 8b3b67f

Browse files
[BUG] Mask out unwanted vertices during negative sampling (#303)
Masks out unwanted vertices during heterogeneous negative sampling, which was previously not being done. This caused de-offsetting to fail and produce negative values for `edge_label_index`, which exposed the bug. Anything sampled with these negatives edge would have been invalid anyways even if de-offsetting returned positive values. Also fixes a bug affecting triplet sampling by concatenating from a random subset of src instead of dst. Closes #304 Partially resolves nvbug#5502562 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - Tingyu Wang (https://github.com/tingyu66) URL: #303
1 parent 525ca06 commit 8b3b67f

File tree

3 files changed

+177
-7
lines changed

3 files changed

+177
-7
lines changed

python/cugraph-pyg/cugraph_pyg/sampler/sampler.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,7 @@ def sample_from_edges(
818818
self.__graph_store,
819819
index.row,
820820
index.col,
821+
index.input_type,
821822
self.__batch_size,
822823
neg_sampling,
823824
None, # src_time,
@@ -826,9 +827,13 @@ def sample_from_edges(
826827
if neg_sampling.is_binary():
827828
src, _ = neg_cat(src.cuda(), src_neg, self.__batch_size)
828829
else:
829-
# triplet, cat dst to src so length is the same; will
830-
# result in the same set of unique vertices
831-
src, _ = neg_cat(src.cuda(), dst_neg, self.__batch_size)
830+
# triplet, cat random subset of src to src so length is the
831+
# same; will result in the same set of unique vertices
832+
scu = src.cuda()
833+
per = torch.randint(
834+
0, scu.numel(), (dst_neg.numel(),), device=scu.device
835+
)
836+
src, _ = neg_cat(scu, scu[per], self.__batch_size)
832837
dst, neg_batch_size = neg_cat(dst.cuda(), dst_neg, self.__batch_size)
833838

834839
# Concatenate -1s so the input id tensor lines up and can

python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def neg_sample(
7979
graph_store: GraphStore,
8080
seed_src: "torch.Tensor",
8181
seed_dst: "torch.Tensor",
82+
input_type: Tuple[str, str, str],
8283
batch_size: int,
8384
neg_sampling: "torch_geometric.sampler.NegativeSampling",
8485
time: "torch.Tensor",
@@ -91,22 +92,85 @@ def neg_sample(
9192
except AttributeError:
9293
src_weight = neg_sampling.weight
9394
dst_weight = neg_sampling.weight
94-
unweighted = src_weight is None and dst_weight is None
9595

9696
# Require at least one negative edge per batch
9797
num_neg = max(
9898
int(ceil(neg_sampling.amount * seed_src.numel())),
9999
int(ceil(seed_src.numel() / batch_size)),
100100
)
101101

102+
# The weights need to match the expected number of nodes
103+
if graph_store.is_homogeneous:
104+
num_src_nodes = num_dst_nodes = list(graph_store._num_vertices().values())[0]
105+
else:
106+
num_src_nodes = graph_store._num_vertices()[input_type[0]]
107+
num_dst_nodes = graph_store._num_vertices()[input_type[2]]
108+
109+
if src_weight is not None and dst_weight is not None:
110+
if src_weight.dtype != dst_weight.dtype:
111+
raise ValueError(
112+
f"The 'src_weight' and 'dst_weight' attributes need to have the same"
113+
f" dtype (got {src_weight.dtype} and {dst_weight.dtype})"
114+
)
115+
weight_dtype = (
116+
torch.float32
117+
if (src_weight is None and dst_weight is None)
118+
else (src_weight.dtype if src_weight is not None else dst_weight.dtype)
119+
)
120+
121+
if src_weight is None:
122+
src_weight = torch.ones(num_src_nodes, dtype=weight_dtype, device="cuda")
123+
else:
124+
if src_weight.numel() != num_src_nodes:
125+
raise ValueError(
126+
f"The 'src_weight' attribute needs to match the number of source nodes"
127+
f" {num_src_nodes} (got {src_weight.numel()})"
128+
)
129+
130+
if dst_weight is None:
131+
dst_weight = torch.ones(num_dst_nodes, dtype=weight_dtype, device="cuda")
132+
else:
133+
if dst_weight.numel() != num_dst_nodes:
134+
raise ValueError(
135+
f"The 'dst_weight' attribute needs to match the number of destination"
136+
f" nodes {num_dst_nodes} (got {dst_weight.numel()})"
137+
)
138+
139+
# If the graph is heterogeneous, the weights need to be concatenated together
140+
# and offset.
141+
if not graph_store.is_homogeneous:
142+
if input_type[0] != input_type[2]:
143+
vertices = torch.concat(
144+
[
145+
torch.arange(num_src_nodes, dtype=torch.int64, device="cuda")
146+
+ graph_store._vertex_offsets[input_type[0]],
147+
torch.arange(num_dst_nodes, dtype=torch.int64, device="cuda")
148+
+ graph_store._vertex_offsets[input_type[2]],
149+
]
150+
)
151+
else:
152+
vertices = (
153+
torch.arange(num_src_nodes, dtype=torch.int64, device="cuda")
154+
+ graph_store._vertex_offsets[input_type[0]]
155+
)
156+
157+
src_weight = torch.concat(
158+
[src_weight, torch.zeros(num_dst_nodes, dtype=weight_dtype, device="cuda")]
159+
)
160+
dst_weight = torch.concat(
161+
[torch.zeros(num_src_nodes, dtype=weight_dtype, device="cuda"), dst_weight]
162+
)
163+
elif src_weight is None and dst_weight is None:
164+
vertices = None
165+
else:
166+
vertices = torch.arange(num_src_nodes, dtype=torch.int64, device="cuda")
167+
102168
if node_time is None:
103169
result_dict = pylibcugraph.negative_sampling(
104170
graph_store._resource_handle,
105171
graph_store._graph,
106172
num_neg,
107-
vertices=None
108-
if unweighted
109-
else cupy.arange(src_weight.numel(), dtype="int64"),
173+
vertices=None if vertices is None else cupy.asarray(vertices),
110174
src_bias=None if src_weight is None else cupy.asarray(src_weight),
111175
dst_bias=None if dst_weight is None else cupy.asarray(dst_weight),
112176
remove_duplicates=False,

python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,104 @@ def test_neighbor_loader_hetero_linkpred_bidirectional_three_types(
597597
assert (r_i == eli_i).all()
598598

599599
assert i == 7
600+
601+
602+
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
603+
@pytest.mark.sg
604+
@pytest.mark.parametrize("batch_size", [1, 2])
605+
@pytest.mark.parametrize("neg_sampling_mode", ["binary", "triplet"])
606+
@pytest.mark.parametrize("amount", [1, 2])
607+
def test_link_neighbor_loader_hetero_negative_sampling(
608+
batch_size, neg_sampling_mode, amount, single_pytorch_worker
609+
):
610+
"""
611+
Test negative sampling for heterogeneous graphs with different edge types.
612+
"""
613+
# Create a heterogeneous graph with paper-author relationships
614+
src_paper = torch.tensor([0, 1, 2, 4, 3, 4, 5, 5]) # paper
615+
dst_paper = torch.tensor([4, 5, 4, 3, 2, 1, 0, 1]) # paper
616+
617+
asrc = torch.tensor([0, 1, 2, 3, 3, 0]) # author
618+
adst = torch.tensor([0, 1, 2, 3, 4, 5]) # paper
619+
620+
num_authors = 4
621+
num_papers = 6
622+
623+
graph_store = GraphStore()
624+
feature_store = FeatureStore()
625+
626+
# Add paper-paper citations
627+
graph_store[("paper", "cites", "paper"), "coo", False, (num_papers, num_papers)] = [
628+
src_paper,
629+
dst_paper,
630+
]
631+
# Add author-paper relationships
632+
graph_store[
633+
("author", "writes", "paper"), "coo", False, (num_authors, num_papers)
634+
] = [asrc, adst]
635+
636+
# Create edge label index for author-paper relationships
637+
edge_label_index = torch.stack([asrc, adst])
638+
639+
# Test both binary and triplet negative sampling
640+
if neg_sampling_mode == "binary":
641+
neg_sampling = torch_geometric.sampler.NegativeSampling(
642+
"binary", amount=float(amount)
643+
)
644+
else:
645+
neg_sampling = torch_geometric.sampler.NegativeSampling(
646+
"triplet", amount=float(amount)
647+
)
648+
649+
loader = cugraph_pyg.loader.LinkNeighborLoader(
650+
(feature_store, graph_store),
651+
num_neighbors={
652+
("paper", "cites", "paper"): [2, 2],
653+
("author", "writes", "paper"): [2, 2],
654+
},
655+
edge_label_index=(("author", "writes", "paper"), edge_label_index),
656+
batch_size=batch_size,
657+
neg_sampling=neg_sampling,
658+
shuffle=False,
659+
)
660+
661+
# Test that the loader produces batches with proper negative sampling
662+
for i, batch in enumerate(loader):
663+
# Check that we have the expected edge label index structure
664+
assert [("author", "writes", "paper")] == list(
665+
batch.edge_label_index_dict.keys()
666+
)
667+
assert [("author", "writes", "paper")] == list(batch.edge_label_dict.keys())
668+
669+
# Should have both positive (1.0) and negative (0.0) labels
670+
edge_labels = batch["author", "writes", "paper"].edge_label
671+
assert torch.any(edge_labels == 1.0)
672+
assert torch.any(edge_labels == 0.0)
673+
assert (edge_labels == 0.0).sum() == amount * (edge_labels == 1.0).sum()
674+
675+
# Verify that the edge label index has the correct shape
676+
edge_label_idx = batch["author", "writes", "paper"].edge_label_index
677+
assert edge_label_idx.shape[0] == 2 # Should be [2, num_edges]
678+
assert edge_label_idx.shape[1] > 0 # Should have some edges
679+
680+
# Verify that the edge labels correspond to the edge label index
681+
assert edge_labels.shape[0] == edge_label_idx.shape[1]
682+
683+
# Check that node IDs are valid
684+
assert batch["author"].n_id.numel() > 0
685+
assert batch["paper"].n_id.numel() > 0
686+
687+
# Verify that edge label index uses valid node IDs
688+
author_n_ids = batch["author"].n_id
689+
paper_n_ids = batch["paper"].n_id
690+
691+
# All source nodes in edge_label_index should be in author.n_id
692+
src_nodes = edge_label_idx[0]
693+
assert torch.all(torch.isin(src_nodes.cpu(), torch.arange(len(author_n_ids))))
694+
695+
# All destination nodes in edge_label_index should be in paper.n_id
696+
dst_nodes = edge_label_idx[1]
697+
assert torch.all(torch.isin(dst_nodes.cpu(), torch.arange(len(paper_n_ids))))
698+
699+
# Verify we processed all batches
700+
assert i >= 0 # At least one batch should be processed

0 commit comments

Comments
 (0)