Conversation

benjaminnNgo left a comment

I heard that you can run CTAN. Can you try to run their CTAN and our CTAN and compare the losses after 5-10 epochs? Use the same seed to see whether they converge to losses that are close to each other. Let's narrow down whether it is the training loop or the evaluation loop that differs.
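For the same-seed comparison, one way to pin down randomness on both sides is a small seeding helper like the sketch below. The `set_seed` name and the optional torch branch are my own illustration, not from either codebase:

```python
import random

import numpy as np


def set_seed(seed: int) -> None:
    """Seed every RNG a training loop typically touches (a sketch)."""
    random.seed(seed)
    np.random.seed(seed)
    try:  # Only if torch is installed; both CTAN implementations use it.
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Optionally force deterministic kernels for a stricter comparison:
        # torch.use_deterministic_algorithms(True)
    except ImportError:
        pass
```

Calling this once at the top of both training scripts, before any data loading or model construction, should make the per-epoch losses directly comparable.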
examples/linkproppred/ctan.py (outdated)

```python
class LastAggregator(torch.nn.Module):
    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        out = msg.new_zeros((dim_size, msg.size(-1)))

        if index.numel() > 0:
            scores = torch.full((dim_size, t.size(0)), float('-inf'), device=t.device)
            scores[index, torch.arange(t.size(0), device=t.device)] = t.float()
            argmax = scores.argmax(dim=1)
            valid = scores.max(dim=1).values > float('-inf')
            out[valid] = msg[argmax[valid]]

        return out
```
I know that we got this from TGN, and we know TGN works. To me, this logic makes sense, but our code differs from CTAN in this module. It may be worth replacing this with torch_scatter as the last thing to try.
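To sanity-check the dense-scores trick against the intended "keep the message with the latest timestamp per node" semantics, here is a NumPy sketch tested against a naive loop. The function names and the NumPy port are illustrative assumptions; the actual modules use PyTorch (and the torch_scatter variant appears later in this diff):

```python
import numpy as np


def last_aggregate(msg, index, t, dim_size):
    """Dense-scores trick: for each target node, keep the message with max t."""
    out = np.zeros((dim_size, msg.shape[-1]), dtype=msg.dtype)
    if index.size > 0:
        # One row per node, one column per message; unassigned cells stay -inf.
        scores = np.full((dim_size, t.shape[0]), -np.inf)
        scores[index, np.arange(t.shape[0])] = t.astype(float)
        argmax = scores.argmax(axis=1)
        valid = scores.max(axis=1) > -np.inf  # Nodes with at least one message.
        out[valid] = msg[argmax[valid]]
    return out


def last_aggregate_naive(msg, index, t, dim_size):
    """Reference loop with the same semantics, trivially correct."""
    out = np.zeros((dim_size, msg.shape[-1]), dtype=msg.dtype)
    best_t = np.full(dim_size, -np.inf)
    for i, node in enumerate(index):
        if t[i] > best_t[node]:  # Strict '>' keeps the first message on ties,
            best_t[node] = t[i]  # matching argmax's first-max behavior.
            out[node] = msg[i]
    return out
```

Checking the two against each other on a few batches (including nodes with no messages) would confirm the dense version before, or after, swapping in torch_scatter.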
examples/linkproppred/ctan.py
Outdated
| class SimpleMemory(torch.nn.Module): | ||
| def __init__( | ||
| self, num_nodes: int, memory_dim: int, aggr_module: Callable, init_time: int = 0 | ||
| ) -> None: | ||
| super().__init__() | ||
|
|
||
| self.num_nodes = num_nodes | ||
| self.memory_dim = memory_dim | ||
| self.init_time = init_time | ||
| self.aggr_module = aggr_module | ||
|
|
||
| self.register_buffer('memory', torch.zeros(num_nodes, memory_dim)) | ||
| self.register_buffer( | ||
| 'last_update', torch.ones(self.num_nodes, dtype=torch.long) * init_time | ||
| ) | ||
| self.register_buffer('_assoc', torch.empty(num_nodes, dtype=torch.long)) | ||
|
|
||
| def update_state(self, src, pos_dst, t, src_emb, pos_dst_emb): | ||
| idx = torch.cat([src, pos_dst], dim=0) | ||
| _idx = idx.unique() | ||
| self._assoc[_idx] = torch.arange(_idx.size(0), device=_idx.device) | ||
|
|
||
| t = torch.cat([t, t], dim=0) | ||
| last_update = scatter(t, self._assoc[idx], 0, _idx.size(0), reduce='max') | ||
|
|
||
| emb = torch.cat([src_emb, pos_dst_emb], dim=0) | ||
| aggr = self.aggr_module(emb, self._assoc[idx], t, _idx.size(0)) | ||
|
|
||
| self.last_update[_idx] = last_update | ||
| self.memory[_idx] = aggr.detach() | ||
|
|
||
| def reset_state(self): | ||
| zeros(self.memory) | ||
| ones(self.last_update) | ||
| self.last_update *= self.init_time | ||
|
|
||
| def detach(self): | ||
| self.memory.detach_() | ||
|
|
||
| def forward(self, n_id): | ||
| return self.memory[n_id], self.last_update[n_id] | ||
|
|
||
|
|
||
| class LastAggregator(torch.nn.Module): | ||
| def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int): | ||
| _, argmax = scatter_max(t, index, dim=0, dim_size=dim_size) | ||
| out = msg.new_zeros((dim_size, msg.size(-1))) | ||
| mask = argmax < msg.size(0) # Filter items with at least one entry. | ||
| out[mask] = msg[argmax[mask]] | ||
| return out | ||
|
|
||
|
|
||
| class TGBLinkPredictor(torch.nn.Module): | ||
| def __init__(self, in_channels): | ||
| super().__init__() | ||
| self.lin_src = torch.nn.Linear(in_channels, in_channels) | ||
| self.lin_dst = torch.nn.Linear(in_channels, in_channels) | ||
| self.lin_final = torch.nn.Linear(in_channels, 1) | ||
|
|
||
| def forward(self, z_src, z_dst): | ||
| h = self.lin_src(z_src) + self.lin_dst(z_dst) | ||
| h = h.relu() | ||
| return self.lin_final(h).view(-1) | ||
|
|
||
|
|
||
| class CTAN(torch.nn.Module): | ||
| def __init__( | ||
| self, | ||
| edge_dim: int, | ||
| memory_dim: int, | ||
| time_dim: int, | ||
| node_dim: int, | ||
| num_iters: int = 1, | ||
| mean_delta_t: float = 0.0, | ||
| std_delta_t: float = 1.0, | ||
| epsilon: float = 0.1, | ||
| gamma: float = 0.1, | ||
| ): | ||
| super().__init__() | ||
| self.mean_delta_t = mean_delta_t | ||
| self.std_delta_t = std_delta_t | ||
| self.time_enc = TimeEncoder(time_dim) | ||
| self.enc_x = nn.Linear(memory_dim + node_dim, memory_dim) | ||
|
|
||
| phi = TransformerConv( | ||
| memory_dim, memory_dim, edge_dim=edge_dim + time_dim, root_weight=False | ||
| ) | ||
| self.aconv = AntiSymmetricConv( | ||
| memory_dim, phi, num_iters=num_iters, epsilon=epsilon, gamma=gamma | ||
| ) | ||
|
|
||
| def forward(self, x, last_update, edge_index, t, msg): | ||
| rel_t = (last_update[edge_index[0]] - t).abs() | ||
| rel_t = ((rel_t - self.mean_delta_t) / self.std_delta_t).to(x.dtype) | ||
| enc_x = self.enc_x(x) | ||
| edge_attr = torch.cat([msg, self.time_enc(rel_t)], dim=-1) | ||
| z = self.aconv(enc_x, edge_index, edge_attr=edge_attr) | ||
| z = torch.tanh(z) | ||
| return z |
This is a direct port of the original as far as I can tell @benjaminnNgo
As written, I'm seeing about 0.53 validation MRR (better than the 0.3 we were getting before), but still less than the reported 0.65 TGB number. I'm not doing early stopping, etc.
So you took the same CTAN as the authors' implementation (not changing torch_scatter)? This may suggest that we did something wrong when we refactored CTAN before. But it is good that we can get up to 0.53 now. I think we can try 2 things:
- You said that you did the grid search for the best configuration with the authors' source code. Let's take the best config, train with the authors' source code, and see how many epochs (call it `n`) CTAN needs to reach its best validation results. Then go back to our implementation and run for `1.5 * n` epochs; we expect the best validation result achieved within `1.5 * n` epochs to be close to `0.65`. (Why `1.5 * n`? Just to allow for randomness making the optimization converge slightly differently.)
- If the first step happens the way we expect, then refactor from this version (changing from torch_scatter to torch operations, etc.) and make sure the performance doesn't change.
~0.56 Test MRR @ 100 epochs with optimal hyperparams.
examples/linkproppred/ctan.py

```python
MERGE_OP = {
class LearnableSumMerge(nn.Module):
```
This is distinct from doing something like `z_src + z_dst` as our merge op. Note that this simple aggregation could not reproduce the CTAN numbers we expect from the original implementation (which has the `LearnableSumMerge` decoder).
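For contrast, here is a hypothetical sketch of a learnable sum merge versus the plain `z_src + z_dst`. The original `LearnableSumMerge` is not shown in this excerpt, so the layer layout below is an assumption, not the authors' code:

```python
import torch
from torch import nn


class LearnableSumMerge(nn.Module):
    """Hypothetical sketch: learn per-role projections before summing,
    rather than adding the raw embeddings directly."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.lin_src = nn.Linear(dim, dim)
        self.lin_dst = nn.Linear(dim, dim)

    def forward(self, z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
        # A plain merge would be `z_src + z_dst`; here each side is
        # transformed first, so the merge has trainable parameters.
        return self.lin_src(z_src) + self.lin_dst(z_dst)
```

The point of the comment stands either way: swapping a parameter-free sum in place of a learnable merge changes the decoder's capacity, which could plausibly account for part of the MRR gap.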
Down the line, this should probably be unified with `TGNMemory`.
Codecov Report
❌ Patch coverage is
Had some log formatting warnings that went through to master.
Job: test_ctan_linkprop_pred_tgbl-wiki
Close #349