Skip to content

Commit b1b24ac

Browse files
authored
GraphMixer Fix (#280)
* Update * upd * wip * simplify * upd * upd * upd * Fix critical decoder bug * upd * upd * upd * upd
1 parent 694c47f commit b1b24ac

File tree

3 files changed

+38
-75
lines changed

3 files changed

+38
-75
lines changed

examples/linkproppred/graphmixer.py

Lines changed: 36 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
2-
from collections import defaultdict, deque
3-
from typing import Any, Deque, Dict
2+
from collections import defaultdict
3+
from dataclasses import replace
44

55
import numpy as np
66
import torch
@@ -16,7 +16,7 @@
1616
RECIPE_TGB_LINK_PRED,
1717
)
1818
from tgm.data import DGData, DGDataLoader
19-
from tgm.hooks import RecencyNeighborHook, RecipeRegistry, StatefulHook
19+
from tgm.hooks import RecencyNeighborHook, RecipeRegistry, StatelessHook
2020
from tgm.nn import LinkPredictor, MLPMixer, Time2Vec
2121
from tgm.util.logging import enable_logging, log_gpu, log_latency, log_metric
2222
from tgm.util.seed import seed_everything
@@ -34,7 +34,7 @@
3434
parser.add_argument('--dropout', type=str, default=0.1, help='dropout rate')
3535
parser.add_argument('--n-nbrs', type=int, default=20, help='num sampled nbrs')
3636
parser.add_argument('--time-dim', type=int, default=100, help='time encoding dimension')
37-
parser.add_argument('--embed-dim', type=int, default=128, help='attention dimension')
37+
parser.add_argument('--embed-dim', type=int, default=128, help='embedding dimension')
3838
parser.add_argument(
3939
'--node-dim', type=int, default=100, help='node feat dimension if not provided'
4040
)
@@ -102,20 +102,24 @@ def forward(self, batch: DGBatch, node_feat: torch.Tensor) -> torch.Tensor:
102102
# Link Encoder
103103
edge_feat = batch.nbr_feats[0]
104104
nbr_time_feat = self.time_encoder(batch.times[0][:, None] - batch.nbr_times[0])
105-
nbr_time_feat[batch.nbr_nids[0] == PADDED_NODE_ID] = 0.0
106105
z_link = self.projection_layer(torch.cat([edge_feat, nbr_time_feat], dim=-1))
107106
for mixer in self.mlp_mixers:
108107
z_link = mixer(z_link)
109-
z_link = torch.mean(z_link, dim=1)
108+
109+
valid_nbrs_mask = batch.nbr_nids[0] != PADDED_NODE_ID
110+
z_link = z_link * valid_nbrs_mask.unsqueeze(-1)
111+
z_link = z_link.sum(dim=1) / valid_nbrs_mask.sum(dim=1, keepdim=True).clamp(
112+
min=1
113+
)
110114

111115
# Node Encoder
112-
time_gap_node_feats = node_feat[batch.time_gap_node_nids]
113-
mask = (batch.time_gap_node_nids != PADDED_NODE_ID).float()
114-
masked_feats = time_gap_node_feats * mask.unsqueeze(-1)
115-
nbr_count = mask.sum(dim=1, keepdim=True).clamp(min=1) # Mean over valid nbrs
116-
agg_feats = masked_feats.sum(dim=1) / nbr_count
117-
z_node = agg_feats + node_feat[torch.cat([batch.src, batch.dst, batch.neg])]
116+
num_nodes, feat_dim = len(batch.time_gap_nbrs), node_feat.shape[1]
117+
time_gap_feat = torch.zeros((num_nodes, feat_dim), device=node_feat.device)
118+
for i in range(num_nodes):
119+
if batch.time_gap_nbrs[i]:
120+
time_gap_feat[i] = node_feat[batch.time_gap_nbrs[i]].mean(dim=0)
118121

122+
z_node = time_gap_feat + node_feat[torch.cat([batch.src, batch.dst, batch.neg])]
119123
z = self.output_layer(torch.cat([z_link, z_node], dim=1))
120124
return z
121125

@@ -196,75 +200,35 @@ def eval(
196200
test_dg = DGraph(test_data, device=args.device)
197201

198202

199-
class GraphMixerHook(StatefulHook):
203+
class GraphMixerHook(StatelessHook):
200204
r"""Custom hook that gets 1-hop neighbors in a specific window.
201205
202206
If N(v_i, t_s, t_e) = nbrs of v_i from [t_s, t_e], then we materialize
203207
N(node_ids, t - TIME_GAP, t) for all seed nodes in a given batch.
204208
"""
205209

206210
requires = {'neg'}
207-
produces = {'time_gap_node_nids'}
211+
produces = {'time_gap_nbrs'}
208212

209213
def __init__(self, time_gap: int) -> None:
210-
self._num_nbrs = time_gap
211-
self._history: Dict[int, Deque[Any]] = defaultdict(
212-
lambda: deque(maxlen=self._num_nbrs)
213-
)
214-
self._device = torch.device('cpu')
215-
216-
def reset_state(self) -> None:
217-
self._history = defaultdict(lambda: deque(maxlen=self._num_nbrs))
214+
self._time_gap = time_gap
218215

219216
def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch:
220-
device = dg.device
221-
self._move_queues_to_device_if_needed(device) # No-op after first batch
222-
223-
batch.neg = batch.neg.to(device)
224-
225-
seed_nodes = torch.cat([batch.src, batch.dst, batch.neg])
226-
seed_times = torch.cat([batch.time.repeat(2), batch.neg_time])
227-
228-
batch.time_gap_node_nids = self._get_recency_neighbors(
229-
seed_nodes, seed_times, self._num_nbrs
230-
)
231-
232-
self._update(batch)
217+
# Construct the time_gap slice
218+
time_gap_slice = replace(dg._slice)
219+
time_gap_slice.start_idx = max(dg._slice.end_idx - self._time_gap, 0)
220+
time_gap_slice.end_time = int(batch.time.min()) - 1
221+
time_gap_src, time_gap_dst, _ = dg._storage.get_edges(time_gap_slice)
222+
223+
nbr_index = defaultdict(list)
224+
for u, v in zip(time_gap_src.tolist(), time_gap_dst.tolist()):
225+
nbr_index[u].append(v)
226+
nbr_index[v].append(u) # undirected
227+
228+
seed_nodes = torch.cat([batch.src, batch.dst, batch.neg.to(dg.device)])
229+
batch.time_gap_nbrs = [nbr_index.get(nid, []) for nid in seed_nodes.tolist()] # type: ignore
233230
return batch
234231

235-
def _get_recency_neighbors(
236-
self, node_ids: torch.Tensor, query_times: torch.Tensor, k: int
237-
) -> torch.Tensor:
238-
num_nodes = node_ids.size(0)
239-
device = node_ids.device
240-
nbr_nids = torch.full(
241-
(num_nodes, k), PADDED_NODE_ID, dtype=torch.long, device=device
242-
)
243-
244-
for i in range(num_nodes):
245-
nid, qtime = int(node_ids[i]), int(query_times[i])
246-
history = self._history[nid]
247-
valid = [(nbr, t) for (nbr, t) in history if t < qtime]
248-
if not valid:
249-
continue
250-
valid = valid[-k:] # most recent k
251-
252-
nbr_nids[i, -len(valid) :] = torch.tensor(
253-
[x[0] for x in valid], dtype=torch.long, device=device
254-
)
255-
256-
return nbr_nids
257-
258-
def _update(self, batch: DGBatch) -> None:
259-
src, dst, time = batch.src.tolist(), batch.dst.tolist(), batch.time.tolist()
260-
for s, d, t in zip(src, dst, time):
261-
self._history[s].append((d, t))
262-
self._history[d].append((s, t)) # undirected
263-
264-
def _move_queues_to_device_if_needed(self, device: torch.device) -> None:
265-
if device != self._device:
266-
self._device = device
267-
268232

269233
hm = RecipeRegistry.build(
270234
RECIPE_TGB_LINK_PRED, dataset_name=args.dataset, train_dg=train_dg
@@ -288,19 +252,19 @@ def _move_queues_to_device_if_needed(self, device: torch.device) -> None:
288252
if train_dg.static_node_feats is not None:
289253
static_node_feats = train_dg.static_node_feats
290254
else:
291-
static_node_feats = torch.randn(
255+
static_node_feats = torch.zeros(
292256
(test_dg.num_nodes, args.node_dim), device=args.device
293257
)
294258

295259
encoder = GraphMixerEncoder(
296-
embed_dim=args.embed_dim,
260+
node_dim=static_node_feats.shape[1],
261+
edge_dim=train_dg.edge_feats_dim,
297262
time_dim=args.time_dim,
263+
embed_dim=args.embed_dim,
298264
num_tokens=args.n_nbrs,
299265
token_dim_expansion=float(args.token_dim_expansion),
300266
channel_dim_expansion=float(args.channel_dim_expansion),
301267
dropout=float(args.dropout),
302-
node_dim=static_node_feats.shape[1],
303-
edge_dim=train_dg.edge_feats_dim | args.embed_dim,
304268
).to(args.device)
305269
decoder = LinkPredictor(node_dim=args.embed_dim, hidden_dim=args.embed_dim).to(
306270
args.device

test/unit/test_nn/test_linkdecoder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ def test_output(edge_factory):
3232

3333
assert not torch.isnan(out).any()
3434
assert len(decoder.model) == 9 # 5 layers + 4 ReLU
35-
assert out.shape[0] == 200
36-
assert out.shape[1] == 1
35+
assert list(out.shape) == [200]
3736

3837
# check the first layer
3938
assert decoder.model[0].in_features == 128 * 2 # concat 2 nodes embeddings

tgm/nn/decoder/linkproppred.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@ def forward(self, z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
6161
z_src (torch.Tensor): embedding of src node
6262
z_dst (torch.Tensor): embedding of dst node
6363
"""
64-
return self.model(self.merge_op(z_src, z_dst))
64+
return self.model(self.merge_op(z_src, z_dst)).view(-1)

0 commit comments

Comments
 (0)