-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtpc.py
More file actions
43 lines (33 loc) · 1.29 KB
/
tpc.py
File metadata and controls
43 lines (33 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
import torch
from sklearn.decomposition import PCA
from beliefprobing import BeliefProbe
from beliefprobing.generate import GenOut
from beliefprobing.probes.beliefprobe import Direction
@BeliefProbe.register('pc_sklearn')
class PrincipalComponent(BeliefProbe):
    """Belief probe whose direction is one principal component of hidden states.

    Hidden states for the negative and positive variants are combined according
    to ``style`` and fitted with scikit-learn's ``PCA``. The component at index
    ``component_i`` is returned as both the belief direction and the normal
    direction.

    Args:
        style: How to combine negative/positive hidden states before PCA. One of
            ``'concat'`` (stack negative rows then positive rows),
            ``'subtract'`` (``neg - pos``), ``'pos-only'`` or ``'neg-only'``.
            An unknown value is only rejected when ``do_train`` runs.
        component_i: Zero-based index of the principal component to use.
        svd_data: Forwarded unchanged to ``BeliefProbe.__init__``.
    """

    def __init__(self, style: str = 'concat', component_i: int = 0, svd_data: bool = False):
        super().__init__(svd_data)
        self.style = style
        # Fit just enough components for index ``component_i`` to exist.
        self.pca = PCA(n_components=component_i + 1)
        self.component_i = component_i

    def do_train(self, gen_out: GenOut) -> float:
        """Fit the PCA on hidden states taken from ``gen_out``.

        ``gen_out`` is unpacked into exactly three items. Only the first
        (``hs``) is used; ``hs[0]`` is treated as the negative hidden states
        and ``hs[1]`` as the positive ones.

        NOTE(review): despite the ``-> float`` annotation this returns ``None``.
        The annotation is kept in case it mirrors the base-class signature;
        confirm what callers expect.

        Raises:
            ValueError: if ``self.style`` is not a supported style.
        """
        hs, _, _ = gen_out
        neg_hs, pos_hs = hs[0], hs[1]
        if self.style == 'concat':
            # Stacked along axis 0: negative rows first, then positive rows.
            both_hs = np.concatenate([neg_hs, pos_hs])
        elif self.style == 'subtract':
            # Elementwise difference; assumes neg/pos shapes are compatible.
            both_hs = neg_hs - pos_hs
        elif self.style == 'pos-only':
            both_hs = pos_hs
        elif self.style == 'neg-only':
            both_hs = neg_hs
        else:
            raise ValueError(
                f"Unknown style {self.style!r}; expected one of "
                "'concat', 'subtract', 'pos-only', 'neg-only'"
            )
        self.pca.fit(both_hs)

    def _get_direction(self) -> Direction:
        """Return the fitted component ``component_i`` as a float32 CPU tensor."""
        # torch.tensor copies, so the result does not alias pca.components_.
        return torch.tensor(self.pca.components_[self.component_i]).squeeze().float().cpu()

    def _belief_direction(self) -> Direction:
        """Belief direction: the selected principal component."""
        return self._get_direction()

    def _normal_direction(self) -> Direction:
        """Normal direction: the same principal component as the belief direction."""
        return self._get_direction()