Backport Temporal Negative Sampling Update/Fix to 26.02 #424

alexbarghi-nv wants to merge 1 commit into rapidsai:release/26.02 from
Conversation
Greptile Summary

This PR backports temporal negative sampling support to release/26.02. The core changes implement node-time-aware negative edge generation.

Key issues found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant LNL as LinkNeighborLoader
    participant BS as BaseSampler
    participant NS as neg_sample()
    participant PLC as _call_plc_negative_sampling()
    participant GS as GraphStore
    participant DS as DistributedSampler
    LNL->>GS: _set_time_attr(feature_store, time_attr)
    LNL->>BS: sample_from_edges(index)
    BS->>GS: _get_ntime_func()
    GS-->>BS: node_time_func (lambda)
    BS->>NS: neg_sample(graph_store, src, dst, input_type, batch_size, neg_sampling, index.time, node_time_func)
    NS->>PLC: _call_plc_negative_sampling(num_neg, vertices, weights)
    PLC-->>NS: src_neg, dst_neg (raw)
    Note over NS: Filter by node_time <= seed_time<br/>(retry up to 5x)
    NS-->>BS: src_neg, dst_neg (temporally valid)
    BS->>BS: neg_cat(src, src_neg, batch_size)
    BS->>BS: neg_cat(dst, dst_neg, batch_size)
    Note over BS: ⚠️ input_time neg_cat args swapped
    BS->>DS: sample_from_edges(seeds, input_time, batch_size)
    DS->>DS: unique_consecutive with leftover_seeds_unique_mask fix
    DS-->>BS: sampled batches
    BS-->>LNL: HeteroSamplerOutput / SamplerOutput
```
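The "Filter by node_time <= seed_time (retry up to 5x)" step above can be sketched as a small rejection-sampling loop. This is an illustrative model only: `sample_fn` is a hypothetical stand-in for `_call_plc_negative_sampling`, the per-edge seed times are simplified to a single scalar, and plain Python lists replace CUDA tensors.

```python
def filter_temporal_negatives(sample_fn, seed_time, num_neg, max_retries=5):
    """Keep only negative pairs whose endpoints exist at or before the seed
    time; re-draw until enough survive or the retry budget is exhausted.

    ``sample_fn(n)`` is assumed to return ``(src, dst, src_time, dst_time)``
    as equal-length sequences of length ``n``.
    """
    kept_src, kept_dst = [], []
    for _ in range(max_retries):
        src, dst, src_t, dst_t = sample_fn(num_neg - len(kept_src))
        for s, d, st, dt in zip(src, dst, src_t, dst_t):
            # A negative edge is temporally valid only if both endpoints
            # already exist at the seed edge's time.
            if st <= seed_time and dt <= seed_time:
                kept_src.append(s)
                kept_dst.append(d)
        if len(kept_src) >= num_neg:
            break
    return kept_src[:num_neg], kept_dst[:num_neg]
```

The real implementation does this filtering vectorized on GPU (the `valid_mask` computation quoted further down), but the control flow is the same: draw, mask, and retry a bounded number of times.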
```python
if node_time is not None and input_time is not None:
    input_time, _ = neg_cat(
        input_time.repeat_interleave(int(ceil(neg_sampling.amount))).cuda(),
        input_time.cuda(),
        self.__batch_size,
    )
```
neg_cat arguments are swapped for input_time

`neg_cat(seed_pos, seed_neg, batch_size)` expects positives as the first argument. Here, the expanded version (`N_pos × ceil(amount)` elements) is passed first, but the positives should come first to match `src`/`dst` (which have `N_pos` elements in the first arg). This swap causes two problems:

- `num_batches` inside `neg_cat` is computed from the first arg's size, so `input_time` ends up with a different batch structure than `src`/`dst`.
- When `amount` is non-integer, the total element count of the resulting `input_time` won't match the combined `src`/`dst` size.
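The batch-structure problem can be made concrete with a little arithmetic. A toy sketch, assuming (as stated above) that `neg_cat` derives its batch count from the first argument's size; the sizes below are hypothetical:

```python
from math import ceil

# Hypothetical sizes: 10 positive edges, amount=1.5 negatives per positive,
# batch size 4.
N_pos, amount, batch_size = 10, 1.5, 4

pos_len = N_pos                          # positives, matching src/dst
neg_len = N_pos * int(ceil(amount))      # expanded input_time: 20 elements

# Correct order: the batch count follows the positive tensor, as for src/dst.
num_batches_correct = ceil(pos_len / batch_size)   # 3
# Swapped order: the batch count follows the expanded tensor instead.
num_batches_swapped = ceil(neg_len / batch_size)   # 5

assert (num_batches_correct, num_batches_swapped) == (3, 5)
```

With the arguments swapped, `input_time` is split into 5 batches while `src`/`dst` are split into 3, so the per-batch alignment between seed times and seed edges is lost.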
The correct call should be:

```python
input_time, _ = neg_cat(
    input_time.cuda(),
    input_time.repeat_interleave(int(ceil(neg_sampling.amount))).cuda(),
    self.__batch_size,
)
```

```python
num_neg_per_pos = int(ceil(neg_sampling.amount))
seed_time = (
    seed_time.view(1, -1).expand(num_neg_per_pos, -1).flatten().cuda()
)

# ...

# TODO modifiy the C API so this condition is impossible
if src_neg.numel() < num_neg:
    num_gen = num_neg - src_neg.numel()
    src_neg = torch.concat(
        [
            src_neg,
            torch.randint(
                0, src_neg.max(), (num_gen,), device="cuda", dtype=torch.int64
            ),

# ...

# For homogeneous graphs, input_type is None, so get the single node type
if graph_store.is_homogeneous:
    node_type = list(graph_store._vertex_offsets.keys())[0]
    node_offset = graph_store._vertex_offsets[node_type]
    src_node_type = dst_node_type = node_type
    src_node_offset = dst_node_offset = node_offset
else:
    src_node_type = input_type[0]
    dst_node_type = input_type[2]
    src_node_offset = graph_store._vertex_offsets[src_node_type]
    dst_node_offset = graph_store._vertex_offsets[dst_node_type]

# ...

src_node_time = node_time_func(src_node_type, src_neg - src_node_offset)
dst_node_time = node_time_func(dst_node_type, dst_neg - dst_node_offset)

# ...

target_samples = src_neg.numel()
valid_mask = (src_node_time <= seed_time) & (dst_node_time <= seed_time)
```
Size mismatch when amount is non-integer

`seed_time` is expanded to `ceil(amount) * N_pos` elements (line 221), but `src_neg` (from the earlier `_call_plc_negative_sampling` call) has `num_neg = max(ceil(amount * N_pos), ceil(N_pos / batch_size))` elements.

When `amount` is non-integer, `ceil(amount * N_pos) ≠ ceil(amount) * N_pos`. For example, with `amount=1.5` and `N_pos=2`: `num_neg = ceil(3) = 3` but `seed_time.numel() = ceil(1.5) * 2 = 4`. The element-wise comparison on line 240 will then raise a broadcast/shape error:

```python
valid_mask = (src_node_time <= seed_time) & (dst_node_time <= seed_time)
```

`seed_time` should be sized to match `num_neg`, not `ceil(amount) * N_pos`. One fix: use `num_neg` when generating the expanded `seed_time`, i.e. use `seed_time.repeat_interleave(...)` up to `num_neg` entries rather than `ceil(amount) * N_pos`.
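The arithmetic behind the example can be checked directly. A sketch using only the size formulas quoted above (`num_neg` follows the `max(...)` expression attributed to `_call_plc_negative_sampling`):

```python
from math import ceil

# The example from the comment: amount=1.5, N_pos=2, large batch size.
N_pos, amount, batch_size = 2, 1.5, 10

# Number of negatives the sampler actually generates:
num_neg = max(ceil(amount * N_pos), ceil(N_pos / batch_size))  # ceil(3.0) = 3

# Number of elements seed_time is expanded to (the buggy ceil(amount) * N_pos):
seed_time_len = int(ceil(amount)) * N_pos                      # 2 * 2 = 4

assert num_neg == 3
assert seed_time_len == 4
# Equal-length tensors are required for the element-wise comparison on
# line 240, so these sizes cannot broadcast against each other.
assert num_neg != seed_time_len
```

For integer `amount` the two expressions coincide (`ceil(a * n) == ceil(a) * n` when `a` is a whole number), which is why the bug only surfaces for fractional sampling amounts.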
```python
if src_neg.numel() == 0:
    # Generate subsample of pseudo-negative edges to avoid edge case where no negative edges are generated.
    # In the next step, these will be used to choose the earlist occuring node for src/dst.
```
Typo in comment

"earlist" → "earliest" (and "occuring" → "occurring")

```diff
- # In the next step, these will be used to choose the earlist occuring node for src/dst.
+ # In the next step, these will be used to choose the earliest occurring node for src/dst.
```
Backports the fixes and changes required for temporal negative sampling to release 26.02.