Skip to content

Commit 413d3ad

Browse files
authored
Merge pull request #2 from tomgrek/node-attrs
Remove phased rotate
2 parents 230d5fa + b1ad484 commit 413d3ad

File tree

1 file changed

+2
-23
lines changed

1 file changed

+2
-23
lines changed

nn/rotate.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,7 @@ def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma,
3636
a=-self.embedding_range.item(),
3737
b=self.embedding_range.item())
3838

39-
if model_name == 'pRotatE':
40-
self.modulus = nn.Parameter(torch.Tensor([[0.5 * self.embedding_range.item()]]))
41-
42-
if model_name not in ['TransE', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE']:
39+
if model_name not in ['ComplEx', 'RotatE']:
4340
raise ValueError('model {} not supported'.format(model_name))
4441

4542
def forward(self, sample, mode='single'):
@@ -111,8 +108,7 @@ def forward(self, sample, mode='single'):
111108

112109
model_func = {
113110
'ComplEx': self.ComplEx,
114-
'RotatE': self.RotatE,
115-
'pRotatE': self.pRotatE
111+
'RotatE': self.RotatE
116112
}
117113

118114
if self.model_name in model_func:
@@ -166,23 +162,6 @@ def RotatE(self, head, relation, tail, mode):
166162
score = self.gamma.item() - score.sum(dim=2)
167163
return score
168164

169-
def pRotatE(self, head, relation, tail, mode):
170-
171-
phase_head = head/(self.embedding_range.item()/math.pi)
172-
phase_relation = relation/(self.embedding_range.item()/math.pi)
173-
phase_tail = tail/(self.embedding_range.item()/math.pi)
174-
175-
if mode == 'head-batch':
176-
score = phase_head + (phase_relation - phase_tail)
177-
else:
178-
score = (phase_head + phase_relation) - phase_tail
179-
180-
score = torch.sin(score)
181-
score = torch.abs(score)
182-
183-
score = self.gamma.item() - score.sum(dim = 2) * self.modulus
184-
return score
185-
186165
@staticmethod
187166
def train_step(model, optimizer, train_iterator, args):
188167
model.train()

0 commit comments

Comments
 (0)