Skip to content

Backport Temporal Negative Sampling Update/Fix to 26.02#424

Open
alexbarghi-nv wants to merge 1 commit intorapidsai:release/26.02from
alexbarghi-nv:hotfix-26.02-temporal-neg-sampling
Open

Backport Temporal Negative Sampling Update/Fix to 26.02#424
alexbarghi-nv wants to merge 1 commit intorapidsai:release/26.02from
alexbarghi-nv:hotfix-26.02-temporal-neg-sampling

Conversation

@alexbarghi-nv
Copy link
Member

Backports the fixes and changes required for temporal negative sampling to release 26.02.

@alexbarghi-nv alexbarghi-nv requested a review from a team as a code owner March 7, 2026 00:26
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 7, 2026

Greptile Summary

This PR backports temporal negative sampling support to release/26.02. The core changes implement node-time-aware negative edge generation in neg_sample(), rename __etime_attr to __time_attr (with a new _get_ntime_func() accessor), fix a leftover_time indexing bug in the distributed sampler's deduplication step, and update temporal_comparison default strings to use underscores.

Key issues found:

  • neg_cat arguments for input_time are swapped (sampler.py lines 844–848): the expanded tensor (N_pos × ceil(amount)) is passed as the positives arg and the original input_time (N_pos) as the negatives arg. This inverts the batch structure relative to src/dst and produces an incorrectly sized seed-time tensor passed to the distributed sampler.
  • Size mismatch for non-integer amount (sampler_utils.py lines 219–240): seed_time is expanded to ceil(amount) × N_pos elements, but src_neg has max(ceil(amount × N_pos), ceil(N_pos / batch_size)) elements. When amount is non-integer (e.g. 1.5) and N_pos > 1, these sizes diverge and the element-wise comparison crashes at runtime.
  • The leftover_time deduplication fix in distributed_sampler.py and the loader-level renaming changes look correct.

Confidence Score: 2/5

  • Not safe to merge — two logic bugs in the new temporal negative sampling path will cause incorrect results or runtime errors.
  • The distributed-sampler fix and renaming changes are clean, but the two core bugs in sampler.py and sampler_utils.py affect the primary feature being backported. The swapped neg_cat arguments produce wrong seed times for temporal neighborhood sampling, and the size mismatch crashes for non-integer amounts — both in the hot path of the new feature.
  • python/cugraph-pyg/cugraph_pyg/sampler/sampler.py and python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py require fixes before merging.

Important Files Changed

Filename Overview
python/cugraph-pyg/cugraph_pyg/sampler/sampler.py Adds node_time function lookup and wires it into neg_sample; introduces input_time expansion for temporal negative sampling, but the neg_cat arguments for input_time are swapped (pos/neg reversed), causing incorrect batch structure and potential size mismatches.
python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py Implements temporal negative sampling with retry logic; size mismatch bug exists when neg_sampling.amount is non-integer because seed_time is expanded by ceil(amount)N_pos while num_neg uses ceil(amountN_pos), causing a broadcast error in the validity mask comparison.
python/cugraph-pyg/cugraph_pyg/sampler/distributed_sampler.py Fixes leftover_time indexing after unique_consecutive by computing a boolean mask of first-occurrence positions before deduplication, correctly aligning times with deduplicated seeds.
python/cugraph-pyg/cugraph_pyg/data/graph_store.py Renames __etime_attr to __time_attr, adds _get_ntime_func() for node-time lookup, and emits a warning when time_attr is present (edge-only temporal sampling for now). Clean and correct.
python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py Updates temporal_comparison default string to use underscores, fixes is_temporal warning condition, and routes time_attr through the renamed _set_time_attr. Looks correct.
python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py Mirror of link_neighbor_loader changes: temporal_comparison default string updated and _set_etime_attr renamed to _set_time_attr. Correct.
python/cugraph-pyg/cugraph_pyg/examples/movielens_mnmg.py Enables temporal negative sampling in the MovieLens example (stores edge/node times, enables edge_label_time and time_attr). Uses inconsistent 2-tuple key ("user", "movie") to access edge data at one location instead of the canonical 3-tuple.
python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py Adds homogeneous and heterogeneous temporal negative sampling tests. Temporal constraint assertions use edge_label_time[i * batch_size] as the bound for all negatives in a batch, which is fragile when batch elements have different times, though it works for the constant-time test data used here.

Sequence Diagram

sequenceDiagram
    participant LNL as LinkNeighborLoader
    participant BS as BaseSampler
    participant NS as neg_sample()
    participant PLC as _call_plc_negative_sampling()
    participant GS as GraphStore
    participant DS as DistributedSampler

    LNL->>GS: _set_time_attr(feature_store, time_attr)
    LNL->>BS: sample_from_edges(index)
    BS->>GS: _get_ntime_func()
    GS-->>BS: node_time_func (lambda)
    BS->>NS: neg_sample(graph_store, src, dst, input_type, batch_size, neg_sampling, index.time, node_time_func)
    NS->>PLC: _call_plc_negative_sampling(num_neg, vertices, weights)
    PLC-->>NS: src_neg, dst_neg (raw)
    Note over NS: Filter by node_time <= seed_time<br/>(retry up to 5x)
    NS-->>BS: src_neg, dst_neg (temporally valid)
    BS->>BS: neg_cat(src, src_neg, batch_size)
    BS->>BS: neg_cat(dst, dst_neg, batch_size)
    Note over BS: ⚠️ input_time neg_cat args swapped
    BS->>DS: sample_from_edges(seeds, input_time, batch_size)
    DS->>DS: unique_consecutive with leftover_seeds_unique_mask fix
    DS-->>BS: sampled batches
    BS-->>LNL: HeteroSamplerOutput / SamplerOutput
Loading

Comments Outside Diff (1)

  1. python/cugraph-pyg/cugraph_pyg/examples/movielens_mnmg.py, line 122-124 (link)

    Inconsistent edge-type key format

    data["user", "movie"].time uses a 2-tuple key, while every other access in this file uses the canonical 3-tuple data["user", "rates", "movie"]. PyG supports shorthand lookup only when the relation name is unambiguous, but relying on it silently can break if a second edge type between the same node types is ever added.

Last reviewed commit: ed64042

Comment on lines +843 to +848
if node_time is not None and input_time is not None:
input_time, _ = neg_cat(
input_time.repeat_interleave(int(ceil(neg_sampling.amount))).cuda(),
input_time.cuda(),
self.__batch_size,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neg_cat arguments are swapped for input_time

neg_cat(seed_pos, seed_neg, batch_size) expects positives as the first argument. Here, the expanded version (N_pos × ceil(amount) elements) is passed first, but the positives should match src/dst (which have N_pos elements in the first arg). This swap causes two problems:

  1. num_batches inside neg_cat is computed from the first arg's size, so input_time ends up with a different batch structure than src/dst.
  2. When amount is non-integer, the total element count of the resulting input_time won't match the combined src/dst size.

The correct call should be:

input_time, _ = neg_cat(
    input_time.cuda(),
    input_time.repeat_interleave(int(ceil(neg_sampling.amount))).cuda(),
    self.__batch_size,
)

Comment on lines +219 to +240
num_neg_per_pos = int(ceil(neg_sampling.amount))
seed_time = (
seed_time.view(1, -1).expand(num_neg_per_pos, -1).flatten().cuda()
)

# TODO modifiy the C API so this condition is impossible
if src_neg.numel() < num_neg:
num_gen = num_neg - src_neg.numel()
src_neg = torch.concat(
[
src_neg,
torch.randint(
0, src_neg.max(), (num_gen,), device="cuda", dtype=torch.int64
),
# For homogeneous graphs, input_type is None, so get the single node type
if graph_store.is_homogeneous:
node_type = list(graph_store._vertex_offsets.keys())[0]
node_offset = graph_store._vertex_offsets[node_type]
src_node_type = dst_node_type = node_type
src_node_offset = dst_node_offset = node_offset
else:
src_node_type = input_type[0]
dst_node_type = input_type[2]
src_node_offset = graph_store._vertex_offsets[src_node_type]
dst_node_offset = graph_store._vertex_offsets[dst_node_type]

src_node_time = node_time_func(src_node_type, src_neg - src_node_offset)
dst_node_time = node_time_func(dst_node_type, dst_neg - dst_node_offset)

target_samples = src_neg.numel()
valid_mask = (src_node_time <= seed_time) & (dst_node_time <= seed_time)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Size mismatch when amount is non-integer

seed_time is expanded to ceil(amount) * N_pos elements (line 221), but src_neg (from the earlier _call_plc_negative_sampling call) has num_neg = max(ceil(amount * N_pos), ceil(N_pos / batch_size)) elements.

When amount is non-integer, ceil(amount * N_pos)ceil(amount) * N_pos. For example, with amount=1.5 and N_pos=2: num_neg = ceil(3) = 3 but seed_time.numel() = ceil(1.5) * 2 = 4. The element-wise comparison on line 240 will then raise a broadcast/shape error:

valid_mask = (src_node_time <= seed_time) & (dst_node_time <= seed_time)

seed_time should be sized to match num_neg, not ceil(amount) * N_pos. One fix: use num_neg when generating the expanded seed_time, i.e. use seed_time.repeat_interleave(...) up to num_neg entries rather than ceil(amount) * N_pos.


if src_neg.numel() == 0:
# Generate subsample of pseudo-negative edges to avoid edge case where no negative edges are generated.
# In the next step, these will be used to choose the earlist occuring node for src/dst.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment

"earlist" → "earliest"

Suggested change
# In the next step, these will be used to choose the earlist occuring node for src/dst.
# In the next step, these will be used to choose the earliest occuring node for src/dst.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant