@@ -36,10 +36,7 @@ def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma,
3636 a = - self .embedding_range .item (),
3737 b = self .embedding_range .item ())
3838
39- if model_name == 'pRotatE' :
40- self .modulus = nn .Parameter (torch .Tensor ([[0.5 * self .embedding_range .item ()]]))
41-
42- if model_name not in ['TransE' , 'DistMult' , 'ComplEx' , 'RotatE' , 'pRotatE' ]:
39+ if model_name not in ['ComplEx' , 'RotatE' ]:
4340 raise ValueError ('model {} not supported' .format (model_name ))
4441
4542 def forward (self , sample , mode = 'single' ):
@@ -111,8 +108,7 @@ def forward(self, sample, mode='single'):
111108
112109 model_func = {
113110 'ComplEx' : self .ComplEx ,
114- 'RotatE' : self .RotatE ,
115- 'pRotatE' : self .pRotatE
111+ 'RotatE' : self .RotatE
116112 }
117113
118114 if self .model_name in model_func :
@@ -166,23 +162,6 @@ def RotatE(self, head, relation, tail, mode):
166162 score = self .gamma .item () - score .sum (dim = 2 )
167163 return score
168164
169- def pRotatE (self , head , relation , tail , mode ):
170-
171- phase_head = head / (self .embedding_range .item ()/ math .pi )
172- phase_relation = relation / (self .embedding_range .item ()/ math .pi )
173- phase_tail = tail / (self .embedding_range .item ()/ math .pi )
174-
175- if mode == 'head-batch' :
176- score = phase_head + (phase_relation - phase_tail )
177- else :
178- score = (phase_head + phase_relation ) - phase_tail
179-
180- score = torch .sin (score )
181- score = torch .abs (score )
182-
183- score = self .gamma .item () - score .sum (dim = 2 ) * self .modulus
184- return score
185-
186165 @staticmethod
187166 def train_step (model , optimizer , train_iterator , args ):
188167 model .train ()
0 commit comments