11import argparse
2- from collections import defaultdict , deque
3- from typing import Any , Deque , Dict
2+ from collections import defaultdict
3+ from dataclasses import replace
44
55import numpy as np
66import torch
1616 RECIPE_TGB_LINK_PRED ,
1717)
1818from tgm .data import DGData , DGDataLoader
19- from tgm .hooks import RecencyNeighborHook , RecipeRegistry , StatefulHook
19+ from tgm .hooks import RecencyNeighborHook , RecipeRegistry , StatelessHook
2020from tgm .nn import LinkPredictor , MLPMixer , Time2Vec
2121from tgm .util .logging import enable_logging , log_gpu , log_latency , log_metric
2222from tgm .util .seed import seed_everything
# Model hyperparameters.
# NOTE: --dropout previously used type=str with a float default, so
# args.dropout was a str when given on the CLI but a float when defaulted;
# type=float makes it consistently a float (downstream float(args.dropout)
# casts remain valid).
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')
parser.add_argument('--n-nbrs', type=int, default=20, help='num sampled nbrs')
parser.add_argument('--time-dim', type=int, default=100, help='time encoding dimension')
parser.add_argument('--embed-dim', type=int, default=128, help='embedding dimension')
parser.add_argument(
    '--node-dim', type=int, default=100, help='node feat dimension if not provided'
)
@@ -102,20 +102,24 @@ def forward(self, batch: DGBatch, node_feat: torch.Tensor) -> torch.Tensor:
102102 # Link Encoder
103103 edge_feat = batch .nbr_feats [0 ]
104104 nbr_time_feat = self .time_encoder (batch .times [0 ][:, None ] - batch .nbr_times [0 ])
105- nbr_time_feat [batch .nbr_nids [0 ] == PADDED_NODE_ID ] = 0.0
106105 z_link = self .projection_layer (torch .cat ([edge_feat , nbr_time_feat ], dim = - 1 ))
107106 for mixer in self .mlp_mixers :
108107 z_link = mixer (z_link )
109- z_link = torch .mean (z_link , dim = 1 )
108+
109+ valid_nbrs_mask = batch .nbr_nids [0 ] != PADDED_NODE_ID
110+ z_link = z_link * valid_nbrs_mask .unsqueeze (- 1 )
111+ z_link = z_link .sum (dim = 1 ) / valid_nbrs_mask .sum (dim = 1 , keepdim = True ).clamp (
112+ min = 1
113+ )
110114
111115 # Node Encoder
112- time_gap_node_feats = node_feat [batch .time_gap_node_nids ]
113- mask = (batch .time_gap_node_nids != PADDED_NODE_ID ).float ()
114- masked_feats = time_gap_node_feats * mask .unsqueeze (- 1 )
115- nbr_count = mask .sum (dim = 1 , keepdim = True ).clamp (min = 1 ) # Mean over valid nbrs
116- agg_feats = masked_feats .sum (dim = 1 ) / nbr_count
117- z_node = agg_feats + node_feat [torch .cat ([batch .src , batch .dst , batch .neg ])]
116+ num_nodes , feat_dim = len (batch .time_gap_nbrs ), node_feat .shape [1 ]
117+ time_gap_feat = torch .zeros ((num_nodes , feat_dim ), device = node_feat .device )
118+ for i in range (num_nodes ):
119+ if batch .time_gap_nbrs [i ]:
120+ time_gap_feat [i ] = node_feat [batch .time_gap_nbrs [i ]].mean (dim = 0 )
118121
122+ z_node = time_gap_feat + node_feat [torch .cat ([batch .src , batch .dst , batch .neg ])]
119123 z = self .output_layer (torch .cat ([z_link , z_node ], dim = 1 ))
120124 return z
121125
@@ -196,75 +200,35 @@ def eval(
# Wrap the held-out test split in a DGraph placed on the target device.
test_dg = DGraph(test_data, device=args.device)
class GraphMixerHook(StatelessHook):
    r"""Custom hook that gets 1-hop neighbors in a specific window.

    If N(v_i, t_s, t_e) = nbrs of v_i from [t_s, t_e], then we materialize
    N(node_ids, t - TIME_GAP, t) for all seed nodes in a given batch.
    """

    requires = {'neg'}
    produces = {'time_gap_nbrs'}

    def __init__(self, time_gap: int) -> None:
        # Number of most recent events that define the lookback window.
        self._time_gap = time_gap

    def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch:
        # Build a storage slice covering the last `time_gap` events that end
        # strictly before the earliest timestamp in this batch.
        # NOTE(review): assumes dg._slice is a (non-frozen) dataclass whose
        # copy may be mutated in place — confirm against DGraph internals.
        window = replace(dg._slice)
        window.start_idx = max(dg._slice.end_idx - self._time_gap, 0)
        window.end_time = int(batch.time.min()) - 1
        window_src, window_dst, _ = dg._storage.get_edges(window)

        # Adjacency lists over the window; each edge contributes both
        # directions since the graph is treated as undirected.
        adjacency = defaultdict(list)
        for u, v in zip(window_src.tolist(), window_dst.tolist()):
            adjacency[u].append(v)
            adjacency[v].append(u)

        # Seed nodes are the batch sources, destinations, and negatives.
        seeds = torch.cat([batch.src, batch.dst, batch.neg.to(dg.device)])
        batch.time_gap_nbrs = [adjacency.get(n, []) for n in seeds.tolist()]  # type: ignore
        return batch
268232
269233hm = RecipeRegistry .build (
270234 RECIPE_TGB_LINK_PRED , dataset_name = args .dataset , train_dg = train_dg
@@ -288,19 +252,19 @@ def _move_queues_to_device_if_needed(self, device: torch.device) -> None:
# Static node features: take the dataset's own features when present,
# otherwise fall back to all-zero vectors of dimension --node-dim.
if train_dg.static_node_feats is None:
    static_node_feats = torch.zeros(
        (test_dg.num_nodes, args.node_dim), device=args.device
    )
else:
    static_node_feats = train_dg.static_node_feats

# Encoder producing per-node embeddings from link + node token mixing.
encoder = GraphMixerEncoder(
    node_dim=static_node_feats.shape[1],
    edge_dim=train_dg.edge_feats_dim,
    time_dim=args.time_dim,
    embed_dim=args.embed_dim,
    num_tokens=args.n_nbrs,
    token_dim_expansion=float(args.token_dim_expansion),
    channel_dim_expansion=float(args.channel_dim_expansion),
    dropout=float(args.dropout),
).to(args.device)
305269decoder = LinkPredictor (node_dim = args .embed_dim , hidden_dim = args .embed_dim ).to (
306270 args .device
0 commit comments