From 6dc2ccb199d9099d3e92d10f2c680671f5726fc0 Mon Sep 17 00:00:00 2001 From: paras Date: Sun, 10 Mar 2024 11:16:12 +0530 Subject: [PATCH] Optimization: Moving iteration on CPU to slice in tensor. Speed up the execution by 3x Optimization: Moving iteration on CPU to slice in tensor. Speed up the execution by little more than 3x. i.e. it will take 1 hr to train if earlier it was taking 3 hr to train. --- torchfm/layer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/torchfm/layer.py b/torchfm/layer.py index ef5d12a..c9b0334 100644 --- a/torchfm/layer.py +++ b/torchfm/layer.py @@ -43,21 +43,27 @@ def __init__(self, field_dims, embed_dim): self.embeddings = torch.nn.ModuleList([ torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields) ]) - self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long) + self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64) for embedding in self.embeddings: torch.nn.init.xavier_uniform_(embedding.weight.data) + self.l1 = [] + self.l2 = [] + for i in range(self.num_fields - 1): + for j in range(i + 1, self.num_fields): + self.l1.append(i) + self.l2.append(j) def forward(self, x): """ :param x: Long tensor of size ``(batch_size, num_fields)`` """ + x = x + x.new_tensor(self.offsets).unsqueeze(0) xs = [self.embeddings[i](x) for i in range(self.num_fields)] - ix = list() - for i in range(self.num_fields - 1): - for j in range(i + 1, self.num_fields): - ix.append(xs[j][:, i] * xs[i][:, j]) - ix = torch.stack(ix, dim=1) + xs = torch.stack(xs, dim = 1) + x1 = xs[:,self.l1, self.l2] + x2 = xs[:,self.l2,self.l1] + ix = x1 * x2 return ix