Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions python/cugraph-pyg/cugraph_pyg/data/graph_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import os
import warnings

import numpy as np
import cupy
Expand All @@ -15,7 +16,7 @@
from cugraph_pyg.tensor import DistTensor, DistMatrix
from cugraph_pyg.tensor.utils import has_nvlink_network, is_empty

from typing import Union, Optional, List, Dict, Tuple
from typing import Union, Optional, List, Dict, Tuple, Callable

# cudf is an optional dependency. It is only imported here for typing.
cudf = import_optional("cudf")
Expand Down Expand Up @@ -70,7 +71,7 @@ def __clear_graph(self):
self.__graph = None
self.__vertex_offsets = None
self.__weight_attr = None
self.__etime_attr = None
self.__time_attr = None
self.__numeric_edge_types = None

def _put_edge_index(
Expand Down Expand Up @@ -317,27 +318,35 @@ def _vertex_offset_array(self) -> "torch.Tensor":
def is_homogeneous(self) -> bool:
return len(self._vertex_offsets) == 1

def _set_etime_attr(self, attr: Tuple["torch_geometric.data.FeatureStore", str]):
if attr != self.__etime_attr:
def _set_time_attr(self, attr: Tuple["torch_geometric.data.FeatureStore", str]):
if attr != self.__time_attr:
weight_attr = self.__weight_attr
self.__clear_graph()
self.__etime_attr = attr
self.__time_attr = attr
self.__weight_attr = weight_attr

def _set_weight_attr(self, attr: Tuple["torch_geometric.data.FeatureStore", str]):
if attr != self.__weight_attr:
etime_attr = self.__etime_attr
time_attr = self.__time_attr
self.__clear_graph()
self.__weight_attr = attr
self.__etime_attr = etime_attr
self.__time_attr = time_attr

def _get_ntime_func(
self,
) -> Optional[Callable[[str, "torch.Tensor"], "torch.Tensor"]]:
if self.__time_attr is None:
return None
feature_store, attr_name = self.__time_attr
return lambda node_type, node_id: feature_store[node_type, attr_name][node_id]

def __get_etime_tensor(
self,
sorted_keys: List[Tuple[str, str, str]],
start_offsets: "torch.Tensor",
num_edges_t: "torch.Tensor",
):
feature_store, attr_name = self.__etime_attr
feature_store, attr_name = self.__time_attr
etimes = []
for i, et in enumerate(sorted_keys):
ix = torch.arange(
Expand Down Expand Up @@ -498,7 +507,12 @@ def __get_edgelist(self):
sorted_keys, start_offsets.cpu(), num_edges_t.cpu()
).cuda()

if self.__etime_attr is not None:
if self.__time_attr is not None:
warnings.warn(
"cuGraph-PyG currently supports only edge-based temporal sampling."
" Node times (if present) can still be used for negative sampling."
)
# TODO if node times are present, do node-based temporal sampling instead.
d["etime"] = self.__get_etime_tensor(
sorted_keys, start_offsets.cpu(), num_edges_t.cpu()
).cuda()
Expand Down
41 changes: 36 additions & 5 deletions python/cugraph-pyg/cugraph_pyg/examples/movielens_mnmg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import os
Expand Down Expand Up @@ -92,6 +92,12 @@ def cugraph_pyg_from_heterodata(data):

feature_store["user", "x", None] = data["user"].x
feature_store["movie", "x", None] = data["movie"].x
feature_store[("user", "rates", "movie"), "time", None] = data[
"user", "rates", "movie"
].time
feature_store[("movie", "rev_rates", "user"), "time", None] = data[
"user", "rates", "movie"
].time

return feature_store, graph_store

Expand Down Expand Up @@ -136,7 +142,18 @@ def preprocess_and_partition(data, edge_path, features_path, meta_path):
fx,
os.path.join(movie_path, f"rank={r}.pt"),
)

time_path = os.path.join(features_path, "time")
os.makedirs(
time_path,
exist_ok=True,
)
for r, time in enumerate(
torch.tensor_split(data["user", "movie"].time, world_size)
):
torch.save(
time,
os.path.join(time_path, f"rank={r}.pt"),
)
print("Writing metadata...")
meta = {
"num_nodes": {
Expand Down Expand Up @@ -190,6 +207,10 @@ def load_partitions(edge_path, features_path, meta_path):
],
dim=1,
)
data["user", "rates", "movie"].time = torch.load(
os.path.join(features_path, "time", f"rank={rank}.pt"),
weights_only=True,
)

label_dict = {
"train": torch.randperm(ei["train"].shape[1]),
Expand Down Expand Up @@ -398,8 +419,19 @@ def test(test_loader, model):
feature_store, graph_store = cugraph_pyg_from_heterodata(data)
eli_train = data["user", "rates", "movie"].edge_index[:, label_dict["train"]]
eli_test = data["user", "rates", "movie"].edge_index[:, label_dict["test"]]
time_train = data["user", "rates", "movie"].time[label_dict["train"]]
num_nodes = {"user": data["user"].num_nodes, "movie": data["movie"].num_nodes}

# Set node times to 0
feature_store["user", "time", None] = torch.tensor_split(
torch.zeros(data["user"].num_nodes, dtype=torch.int64, device=device),
world_size,
)[global_rank]
feature_store["movie", "time", None] = torch.tensor_split(
torch.zeros(data["movie"].num_nodes, dtype=torch.int64, device=device),
world_size,
)[global_rank]

# Extract feature dimensions
num_features = {
"user": data["user"].x.shape[-1] if data["user"].x is not None else 1,
Expand All @@ -417,17 +449,16 @@ def test(test_loader, model):
("movie", "rev_rates", "user"): [5, 5, 5],
},
batch_size=256,
# time_attr='time',
shuffle=True,
drop_last=True,
# temporal_strategy='last',
)

from cugraph_pyg.loader import LinkNeighborLoader

train_loader = LinkNeighborLoader(
edge_label_index=(("user", "rates", "movie"), eli_train),
# edge_label_time=time[train_index] - 1, # No leakage.
edge_label_time=time_train - 1, # No leakage.
time_attr="time",
neg_sampling=dict(mode="binary", amount=2),
**kwargs,
)
Expand Down
14 changes: 7 additions & 7 deletions python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import warnings
Expand Down Expand Up @@ -129,10 +129,10 @@ def __init__(
all workers. If not provided, it will be automatically
calculated.
See cugraph_pyg.sampler.BaseDistributedSampler.
temporal_comparison: str (optional, default='monotonically decreasing')
temporal_comparison: str (optional, default='monotonically_decreasing')
The comparison operator for temporal sampling
('strictly increasing', 'monotonically increasing',
'strictly decreasing', 'monotonically decreasing', 'last').
('strictly_increasing', 'monotonically_increasing',
'strictly_decreasing', 'monotonically_decreasing', 'last').
Note that this should be 'last' for temporal_strategy='last'.
See cugraph_pyg.sampler.BaseDistributedSampler.
**kwargs
Expand All @@ -142,7 +142,7 @@ def __init__(
subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type)

if temporal_comparison is None:
temporal_comparison = "monotonically decreasing"
temporal_comparison = "monotonically_decreasing"

if not directed:
subgraph_type = torch_geometric.sampler.base.SubgraphType.induced
Expand Down Expand Up @@ -182,15 +182,15 @@ def __init__(

is_temporal = (edge_label_time is not None) and (time_attr is not None)

if (edge_label_time is None) != (time_attr is None):
if not is_temporal and (edge_label_time is not None or time_attr is not None):
warnings.warn(
"Edge-based temporal sampling requires that both edge_label_time and time_attr are provided. Defaulting to non-temporal sampling."
)

if weight_attr is not None:
graph_store._set_weight_attr((feature_store, weight_attr))
if is_temporal:
graph_store._set_etime_attr((feature_store, time_attr))
graph_store._set_time_attr((feature_store, time_attr))

if isinstance(num_neighbors, dict):
sorted_keys, _, _ = graph_store._numeric_edge_types
Expand Down
12 changes: 6 additions & 6 deletions python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import warnings
Expand Down Expand Up @@ -122,10 +122,10 @@ def __init__(
all workers. If not provided, it will be automatically
calculated.
See cugraph_pyg.sampler.BaseDistributedSampler.
temporal_comparison: str (optional, default='monotonically decreasing')
temporal_comparison: str (optional, default='monotonically_decreasing')
The comparison operator for temporal sampling
('strictly increasing', 'monotonically increasing',
'strictly decreasing', 'monotonically decreasing', 'last').
('strictly_increasing', 'monotonically_increasing',
'strictly_decreasing', 'monotonically_decreasing', 'last').
Note that this should be 'last' for temporal_strategy='last'.
See cugraph_pyg.sampler.BaseDistributedSampler.
**kwargs
Expand All @@ -135,7 +135,7 @@ def __init__(
subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type)

if temporal_comparison is None:
temporal_comparison = "monotonically decreasing"
temporal_comparison = "monotonically_decreasing"

if not directed:
subgraph_type = torch_geometric.sampler.base.SubgraphType.induced
Expand Down Expand Up @@ -176,7 +176,7 @@ def __init__(
is_temporal = time_attr is not None

if is_temporal:
graph_store._set_etime_attr((feature_store, time_attr))
graph_store._set_time_attr((feature_store, time_attr))

if input_time is None:
input_type, input_nodes, _ = (
Expand Down
22 changes: 19 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/sampler/distributed_sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import warnings
Expand Down Expand Up @@ -556,9 +556,25 @@ def __sample_from_edges_func(
leftover_time = leftover_time[lyi]

lz = torch.sort(lyi)[1]
leftover_seeds, lui = leftover_seeds.unique_consecutive(return_inverse=True)
if leftover_time is not None:
leftover_time = leftover_time[lui]
if leftover_seeds.numel() == 0:
assert leftover_time.numel() == 0, (
"Leftover time should be empty if leftover seeds are empty"
)
leftover_seeds_unique_mask = torch.tensor(
[], device="cuda", dtype=torch.bool
)
else:
leftover_seeds_unique_mask = torch.concat(
[
torch.tensor([True], device="cuda"),
leftover_seeds[1:] != leftover_seeds[:-1],
]
)
leftover_seeds, lui = leftover_seeds.unique_consecutive(return_inverse=True)
leftover_time = leftover_time[leftover_seeds_unique_mask]
else:
leftover_seeds, lui = leftover_seeds.unique_consecutive(return_inverse=True)
leftover_inv = lui[lz]

if leftover_seeds.numel() > 0:
Expand Down
25 changes: 19 additions & 6 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Iterator, Union, Dict, Tuple, List

from math import ceil

from cugraph_pyg.utils.imports import import_optional
from cugraph_pyg.sampler.distributed_sampler import DistributedNeighborSampler

Expand Down Expand Up @@ -797,7 +799,7 @@ def sample_from_nodes(
def sample_from_edges(
self,
index: "torch_geometric.sampler.EdgeSamplerInput",
neg_sampling: Optional["torch_geometric.sampler.NegativeSampling"],
neg_sampling: Optional["torch_geometric.sampler.NegativeSampling"] = None,
**kwargs,
) -> Iterator[
Union[
Expand All @@ -808,19 +810,23 @@ def sample_from_edges(
src = index.row
dst = index.col
input_id = index.input_id
input_time = index.time

# TODO ensure this is handled correctly when disjoint sampling is implemented.
node_time = self.__graph_store._get_ntime_func()

neg_batch_size = 0
if neg_sampling:
# Sample every negative subset at once.
# TODO handle temporal sampling (node_time)
src_neg, dst_neg = neg_sample(
self.__graph_store,
index.row,
index.col,
index.input_type,
self.__batch_size,
neg_sampling,
None, # src_time,
None, # src_node_time,
index.time,
node_time,
)
if neg_sampling.is_binary():
src, _ = neg_cat(src.cuda(), src_neg, self.__batch_size)
Expand All @@ -834,6 +840,13 @@ def sample_from_edges(
src, _ = neg_cat(scu, scu[per], self.__batch_size)
dst, neg_batch_size = neg_cat(dst.cuda(), dst_neg, self.__batch_size)

if node_time is not None and input_time is not None:
input_time, _ = neg_cat(
input_time.repeat_interleave(int(ceil(neg_sampling.amount))).cuda(),
input_time.cuda(),
self.__batch_size,
)
Comment on lines +843 to +848
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neg_cat arguments are swapped for input_time

neg_cat(seed_pos, seed_neg, batch_size) expects positives as the first argument. Here, the expanded version (N_pos × ceil(amount) elements) is passed first, but the positives should match src/dst (which have N_pos elements in the first arg). This swap causes two problems:

  1. num_batches inside neg_cat is computed from the first arg's size, so input_time ends up with a different batch structure than src/dst.
  2. When amount is non-integer, the total element count of the resulting input_time won't match the combined src/dst size.

The correct call should be:

input_time, _ = neg_cat(
    input_time.cuda(),
    input_time.repeat_interleave(int(ceil(neg_sampling.amount))).cuda(),
    self.__batch_size,
)


# Concatenate -1s so the input id tensor lines up and can
# be processed by the dist sampler.
# When loading the output batch, '-1' will be dropped.
Expand All @@ -858,7 +871,7 @@ def sample_from_edges(
reader = self.__sampler.sample_from_edges(
torch.stack([src, dst]), # reverse of usual convention
input_id=input_id,
input_time=index.time,
input_time=input_time,
input_label=index.label,
batch_size=self.__batch_size + neg_batch_size,
metadata=metadata,
Expand Down
Loading
Loading