CTAN #338

Merged
Jacob-Chmura merged 31 commits into main from examples/ctan on Jan 21, 2026
Conversation


@Jacob-Chmura Jacob-Chmura commented Nov 19, 2025

Close #349

@Jacob-Chmura Jacob-Chmura self-assigned this Nov 19, 2025

@benjaminnNgo benjaminnNgo left a comment


I heard that you can run CTAN. Can you run their CTAN and our CTAN and compare the loss after 5-10 epochs? Use the same seed (as in the sketch below) to see whether they converge to similar losses. That will narrow down whether the difference is in the training loop or the evaluation loop.
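
A minimal sketch of the seeding part, assuming both repos expose a plain PyTorch training script; the comparison itself would be running 5-10 epochs in each repo and diffing the per-epoch losses:

import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Seed every RNG a training loop typically touches so both runs match.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(0)  # call this at the top of both scripts, then compare losses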

Comment on lines 56 to 67
class LastAggregator(torch.nn.Module):
    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        out = msg.new_zeros((dim_size, msg.size(-1)))

        if index.numel() > 0:
            # Dense (dim_size, num_messages) score matrix: argmax over dim=1
            # picks, per destination node, the message with the latest timestamp.
            scores = torch.full((dim_size, t.size(0)), float('-inf'), device=t.device)
            scores[index, torch.arange(t.size(0), device=t.device)] = t.float()
            argmax = scores.argmax(dim=1)
            # Rows that received no message stay all -inf; mask them out.
            valid = scores.max(dim=1).values > float('-inf')
            out[valid] = msg[argmax[valid]]

        return out
Contributor:

I know that we got this from TGN, and we know TGN works. To me, this logic makes sense, but our code differs from CTAN in this module. It may be worth replacing it with torch_scatter (as in the scatter_max version below) as the last thing to try.

@Jacob-Chmura Jacob-Chmura marked this pull request as ready for review December 4, 2025 18:19
@Jacob-Chmura Jacob-Chmura marked this pull request as draft December 4, 2025 18:19
Comment on lines 61 to 159
class SimpleMemory(torch.nn.Module):
def __init__(
self, num_nodes: int, memory_dim: int, aggr_module: Callable, init_time: int = 0
) -> None:
super().__init__()

self.num_nodes = num_nodes
self.memory_dim = memory_dim
self.init_time = init_time
self.aggr_module = aggr_module

self.register_buffer('memory', torch.zeros(num_nodes, memory_dim))
self.register_buffer(
'last_update', torch.ones(self.num_nodes, dtype=torch.long) * init_time
)
self.register_buffer('_assoc', torch.empty(num_nodes, dtype=torch.long))

    def update_state(self, src, pos_dst, t, src_emb, pos_dst_emb):
        # Compactly re-index the unique nodes touched by this batch.
        idx = torch.cat([src, pos_dst], dim=0)
        _idx = idx.unique()
        self._assoc[_idx] = torch.arange(_idx.size(0), device=_idx.device)

        # Each event updates both endpoints, so duplicate the timestamps.
        t = torch.cat([t, t], dim=0)
        last_update = scatter(t, self._assoc[idx], 0, _idx.size(0), reduce='max')

        # Aggregate embeddings per node (e.g. keep the latest via LastAggregator).
        emb = torch.cat([src_emb, pos_dst_emb], dim=0)
        aggr = self.aggr_module(emb, self._assoc[idx], t, _idx.size(0))

        self.last_update[_idx] = last_update
        self.memory[_idx] = aggr.detach()

def reset_state(self):
zeros(self.memory)
ones(self.last_update)
self.last_update *= self.init_time

def detach(self):
self.memory.detach_()

def forward(self, n_id):
return self.memory[n_id], self.last_update[n_id]


class LastAggregator(torch.nn.Module):
    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        # TGN-style: keep, per node, the message with the latest timestamp.
        _, argmax = scatter_max(t, index, dim=0, dim_size=dim_size)
out = msg.new_zeros((dim_size, msg.size(-1)))
mask = argmax < msg.size(0) # Filter items with at least one entry.
out[mask] = msg[argmax[mask]]
return out


class TGBLinkPredictor(torch.nn.Module):
def __init__(self, in_channels):
super().__init__()
self.lin_src = torch.nn.Linear(in_channels, in_channels)
self.lin_dst = torch.nn.Linear(in_channels, in_channels)
self.lin_final = torch.nn.Linear(in_channels, 1)

def forward(self, z_src, z_dst):
h = self.lin_src(z_src) + self.lin_dst(z_dst)
h = h.relu()
return self.lin_final(h).view(-1)


class CTAN(torch.nn.Module):
def __init__(
self,
edge_dim: int,
memory_dim: int,
time_dim: int,
node_dim: int,
num_iters: int = 1,
mean_delta_t: float = 0.0,
std_delta_t: float = 1.0,
epsilon: float = 0.1,
gamma: float = 0.1,
):
super().__init__()
self.mean_delta_t = mean_delta_t
self.std_delta_t = std_delta_t
self.time_enc = TimeEncoder(time_dim)
self.enc_x = nn.Linear(memory_dim + node_dim, memory_dim)

phi = TransformerConv(
memory_dim, memory_dim, edge_dim=edge_dim + time_dim, root_weight=False
)
self.aconv = AntiSymmetricConv(
memory_dim, phi, num_iters=num_iters, epsilon=epsilon, gamma=gamma
)

    def forward(self, x, last_update, edge_index, t, msg):
        # Normalized elapsed time since each source node was last updated.
        rel_t = (last_update[edge_index[0]] - t).abs()
        rel_t = ((rel_t - self.mean_delta_t) / self.std_delta_t).to(x.dtype)
        enc_x = self.enc_x(x)
        # Edge features: raw message concatenated with the time encoding.
        edge_attr = torch.cat([msg, self.time_enc(rel_t)], dim=-1)
        z = self.aconv(enc_x, edge_index, edge_attr=edge_attr)
        z = torch.tanh(z)
        return z
@Jacob-Chmura (Member, Author):

This is a direct port of the original as far as I can tell, @benjaminnNgo.

As written, I'm seeing about 0.53 validation MRR (better than the 0.3 we were getting before), but still below the reported 0.65 TGB number. I'm not doing early stopping, etc.
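
For reference, a minimal sketch of how the modules above compose on one batch. This is not the PR's actual training loop; `num_nodes`, `msg_dim`, `node_feat`, `n_id`, `assoc`, and the batch tensors `src`, `dst`, `t`, `msg`, `edge_index` are hypothetical placeholders in the usual TGN-example style.

memory = SimpleMemory(num_nodes, memory_dim=100, aggr_module=LastAggregator())
encoder = CTAN(edge_dim=msg_dim, memory_dim=100, time_dim=100, node_dim=node_feat.size(-1))
decoder = TGBLinkPredictor(in_channels=100)

mem, last_update = memory(n_id)                # read memory for the subgraph nodes
x = torch.cat([mem, node_feat[n_id]], dim=-1)  # enc_x expects memory_dim + node_dim
z = encoder(x, last_update, edge_index, t, msg)
score = decoder(z[assoc[src]], z[assoc[dst]])  # link logits for positive pairs
memory.update_state(src, dst, t, z[assoc[src]], z[assoc[dst]])  # write back, detached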

@benjaminnNgo benjaminnNgo (Contributor) commented Dec 5, 2025

So you took the same CTAN as the authors' implementation (not changing torch_scatter)? This may suggest that we did something wrong when we refactored CTAN before. But it is good that we can get up to 0.53 now. I think we can try 2 things:

  1. You said that you did a grid search over the authors' source code for the best configuration. Let's take the best config, train with the authors' source code, and see how many epochs (call it n) CTAN needs to reach its best validation result. Then go back to our implementation and run for 1.5 * n epochs; we expect the best validation result within that budget to be close to 0.65. (Why 1.5 * n? Just to allow for randomness making the optimization converge slightly differently. See the sketch after this list.)
  2. If the first step goes as we expect, then refactor from this version (changing from torch_scatter to plain torch operations, etc.) and make sure the performance doesn't change.
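
A rough sketch of the step-1 check, where `n`, `model`, `train_epoch`, and `evaluate` are hypothetical stand-ins for the epoch count found in the authors' code and for our training and TGB evaluation loops:

budget = int(1.5 * n)  # n = epochs the authors' code needed to hit its best val MRR
best_val_mrr = 0.0
for epoch in range(1, budget + 1):
    loss = train_epoch(model)         # hypothetical training-epoch helper
    val_mrr = evaluate(model, 'val')  # hypothetical TGB evaluation helper
    best_val_mrr = max(best_val_mrr, val_mrr)
print(f'best val MRR within {budget} epochs: {best_val_mrr:.3f}')  # expect close to 0.65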

@Jacob-Chmura (Member, Author):

~0.56 Test MRR @ 100 epochs with optimal hyperparams.



MERGE_OP = {
class LearnableSumMerge(nn.Module):
@Jacob-Chmura (Member, Author):

This is distinct from doing something like z_src + z_dst as our merge op. Note that this simple aggregation could not reproduce the CTAN numbers we expect from the original implementation (which uses the LearnableSumMerge decoder).
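
For context, a hedged sketch of what a learnable sum merge could look like. The actual class body is truncated in the snippet above, so this is an assumption based on the name: each endpoint gets its own linear projection before the sum, giving the merge parameters to learn, unlike a plain z_src + z_dst.

import torch
from torch import Tensor, nn

class LearnableSumMergeSketch(nn.Module):  # hypothetical; not the PR's class
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.lin_src = nn.Linear(dim, dim)
        self.lin_dst = nn.Linear(dim, dim)

    def forward(self, z_src: Tensor, z_dst: Tensor) -> Tensor:
        # Learned per-endpoint projections, then element-wise sum.
        return self.lin_src(z_src) + self.lin_dst(z_dst)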

@Jacob-Chmura Jacob-Chmura marked this pull request as ready for review January 20, 2026 22:41
@Jacob-Chmura (Member, Author):

Down the line, this should probably be unified with TGNMemory.

@tgm-team tgm-team deleted a comment from codecov bot Jan 20, 2026

codecov bot commented Jan 20, 2026

Codecov Report

❌ Patch coverage is 98.66667% with 1 line in your changes missing coverage. Please review.

Files with missing lines    Patch %    Lines
tgm/nn/encoder/ctan.py      98.21%     1 Missing ⚠️


@Jacob-Chmura (Member, Author):

Had some log formatting warnings that went through to master

@github-actions

Job: test_ctan_linkprop_pred_tgbl-wiki

Metric                        Value
materialize latency           0.001
execute_active_hooks latency  0.012
train latency                 8.407
train peak_gpu_mb             543.592
train alloc_gpu_mb            531.662
eval latency                  13.983
eval peak_gpu_mb              710.904
eval alloc_gpu_mb             465.136
Loss epoch 1                  330.375
Validation mrr epoch 1        0.227
Test mrr epoch 1              0.215


@shenyangHuang shenyangHuang self-requested a review January 21, 2026 14:23
@Jacob-Chmura Jacob-Chmura merged commit 7f0d2f0 into main Jan 21, 2026
7 checks passed
@Jacob-Chmura Jacob-Chmura deleted the examples/ctan branch January 21, 2026 15:41

Development

Successfully merging this pull request may close these issues.

CTAN Implementation

3 participants